From ad9d6288c843735df2c8a9f6c6289fed74268f38 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 10 Apr 2024 17:20:25 +0200 Subject: [PATCH] fix: fix CohereForAI/c4ai-command-r-plus (#1707) @Narsil @drbh this will update flash attention v2 and vllm. You will need to re-install them. --- Dockerfile | 13 +- launcher/src/main.rs | 8 + router/src/infer.rs | 78 ++--- router/src/lib.rs | 27 +- server/Makefile | 3 - server/Makefile-flash-att-v2 | 2 +- server/Makefile-vllm | 4 +- .../models/cache_manager.py | 2 +- .../custom_modeling/flash_cohere_modeling.py | 188 ++++++++---- .../custom_modeling/flash_dbrx_modeling.py | 270 ++---------------- .../custom_modeling/flash_mixtral_modeling.py | 263 ++--------------- .../models/flash_causal_lm.py | 7 +- .../models/flash_cohere.py | 7 +- .../models/flash_mistral.py | 2 +- .../utils/flash_attn.py | 3 + server/text_generation_server/utils/layers.py | 9 +- .../utils/paged_attention.py | 15 +- 17 files changed, 281 insertions(+), 620 deletions(-) diff --git a/Dockerfile b/Dockerfile index e79372a3..0bc5f8d9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -85,7 +85,7 @@ FROM pytorch-install as kernel-builder ARG MAX_JOBS=8 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - ninja-build \ + ninja-build cmake \ && rm -rf /var/lib/apt/lists/* # Build Flash Attention CUDA kernels @@ -160,11 +160,6 @@ WORKDIR /usr/src COPY server/Makefile-selective-scan Makefile RUN make build-all -# Build megablocks -FROM kernel-builder as megablocks-builder - -RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e - # Text Generation Inference base image FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base @@ -186,8 +181,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins curl \ && rm -rf /var/lib/apt/lists/* -# Copy conda with PyTorch and Megablocks installed -COPY --from=megablocks-builder /opt/conda /opt/conda +# Copy conda with PyTorch installed +COPY --from=pytorch-install /opt/conda /opt/conda # Copy build artifacts from flash attention builder COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages @@ -215,7 +210,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages -# Install flash-attention dependencies +# Install vllm/flash-attention dependencies RUN pip install einops --no-cache-dir # Install server diff --git a/launcher/src/main.rs b/launcher/src/main.rs index aef09433..3f810023 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -499,6 +499,9 @@ fn shard_manager( // Copy current process env let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + // Remove LOG_LEVEL if present + envs.retain(|(name, _)| name != "LOG_LEVEL"); + // Torch Distributed Env vars envs.push(("RANK".into(), rank.to_string().into())); envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); @@ -586,6 +589,7 @@ fn shard_manager( tracing::info!("Starting shard"); let mut p = match Command::new("text-generation-server") .args(shard_args) + .env_clear() .envs(envs) .stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -824,6 +828,9 @@ fn download_convert_model(args: &Args, running: Arc) -> 
Result<(), L // Copy current process env let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + // Remove LOG_LEVEL if present + envs.retain(|(name, _)| name != "LOG_LEVEL"); + // Disable progress bar envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into())); @@ -858,6 +865,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L tracing::info!("Starting download process."); let mut download_process = match Command::new("text-generation-server") .args(download_args) + .env_clear() .envs(envs) .stdout(Stdio::piped()) .stderr(Stdio::piped()) diff --git a/router/src/infer.rs b/router/src/infer.rs index e5517511..075e76d8 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,8 +1,8 @@ /// Batching and inference logic use crate::validation::{Validation, ValidationError}; use crate::{ - ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig, - Message, PrefillToken, Queue, Token, + ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, + HubTokenizerConfig, Message, PrefillToken, Queue, Token, }; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; @@ -86,7 +86,18 @@ impl Infer { let chat_template = tokenizer_config .chat_template - .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)); + .and_then(|t| match t { + ChatTemplateVersions::Single(template) => Some(template), + ChatTemplateVersions::Multiple(templates) => templates + .into_iter() + .find(|t| t.name == "default") + .map(|t| t.template), + }) + .map(|t| { + // .strip() is not supported in minijinja + let t = t.replace(".strip()", " | trim"); + ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) + }); // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); @@ -1099,7 +1110,7 @@ mod tests { ChatTemplateTestItem { name: "_base", chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1110,7 +1121,7 @@ mod tests { ChatTemplateTestItem { name: "blenderbot", chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1121,7 +1132,7 @@ mod tests { ChatTemplateTestItem { name: "blenderbot_small", chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1132,7 +1143,7 @@ mod tests { ChatTemplateTestItem { name: "bloom", chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1143,7 +1154,7 @@ mod tests { ChatTemplateTestItem { name: "gpt_neox", 
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1154,38 +1165,37 @@ mod tests { ChatTemplateTestItem { name: "gpt2", chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs{ - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some("<|endoftext|>"), + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), }, - target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", }, ChatTemplateTestItem { name: "llama", // NOTE: the `.strip()` has been replaced with `| trim` in the following template chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content | trim + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}", - input: ChatTemplateInputs{ - messages: example_chat_with_system.clone(), - add_generation_prompt: true, - bos_token: Some(""), - eos_token: Some(""), + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some(""), }, - target: "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + target: "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", }, ChatTemplateTestItem { name: "whisper", chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs{ - messages: example_chat.clone(), - add_generation_prompt: true, - bos_token: Some(""), - eos_token: Some("<|endoftext|>"), + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), }, - target: "Hello, how are you?<|endoftext|>I'm doing great. 
How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - } - + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", + }, ]; #[allow(unused_variables)] // name is unused @@ -1211,7 +1221,7 @@ mod tests { messages: example_chat_with_system.clone(), add_generation_prompt: false, bos_token: Some(""), - eos_token: Some("") + eos_token: Some(""), }, target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHello, how are you?<|assistant|>\nI'm doing great. How can I help you today?<|user|>\nI'd like to show off how chat templating works!", }, @@ -1237,7 +1247,7 @@ mod tests { bos_token: Some(""), eos_token: Some(""), }, - target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHow many helicopters can a human eat in one sitting?<|assistant|>" + target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHow many helicopters can a human eat in one sitting?<|assistant|>", }, ChatTemplateTestItem { name: "HuggingFaceH4/zephyr-7b-gemma-v0.1", @@ -1259,7 +1269,7 @@ mod tests { bos_token: Some(""), eos_token: Some(""), }, - target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", }, ChatTemplateTestItem { name: "mistralai/Mixtral-8x7B-Instruct-v0.1", @@ -1276,7 +1286,7 @@ mod tests { name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b", chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", input: ChatTemplateInputs { - messages: example_chat.clone(), + messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), @@ -1360,7 +1370,7 @@ mod tests { bos_token: Some(""), eos_token: Some(""), }, - target: "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!" + target: "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!", }, ChatTemplateTestItem { name: "internlm/internlm2-chat-7b", @@ -1443,7 +1453,7 @@ mod tests { eos_token: Some(""), }, target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. 
How can I help you today?### Instruction: I'd like to show off how chat templating works!", - } + }, ]; #[allow(unused_variables)] // name is unused diff --git a/router/src/lib.rs b/router/src/lib.rs index c787470b..2e412f1a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -49,9 +49,22 @@ pub struct HubModelInfo { pub pipeline_tag: Option, } -#[derive(Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct ChatTemplate { + name: String, + template: String, +} + +#[derive(Debug, Clone, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum ChatTemplateVersions { + Single(String), + Multiple(Vec), +} + +#[derive(Debug, Clone, Deserialize, Default)] pub struct HubTokenizerConfig { - pub chat_template: Option, + pub chat_template: Option, pub completion_template: Option, #[serde(deserialize_with = "token_serde::deserialize")] pub bos_token: Option, @@ -978,7 +991,10 @@ mod tests { let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap(); // check that we successfully parsed the tokens - assert_eq!(config.chat_template, Some("test".to_string())); + assert_eq!( + config.chat_template, + Some(ChatTemplateVersions::Single("test".to_string())) + ); assert_eq!( config.bos_token, Some("<|begin▁of▁sentence|>".to_string()) @@ -1010,7 +1026,10 @@ mod tests { let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap(); // check that we successfully parsed the tokens - assert_eq!(config.chat_template, Some("test".to_string())); + assert_eq!( + config.chat_template, + Some(ChatTemplateVersions::Single("test".to_string())) + ); assert_eq!( config.bos_token, Some("<|begin▁of▁sentence|>".to_string()) diff --git a/server/Makefile b/server/Makefile index da5171b2..32d01709 100644 --- a/server/Makefile +++ b/server/Makefile @@ -17,9 +17,6 @@ gen-server: find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . 
\1/g' {} \; touch text_generation_server/pb/__init__.py -install-megablocks: - pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e - install: gen-server pip install pip --upgrade pip install -r requirements_cuda.txt diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 71c6cabe..803b3d1f 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3 +flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69 diff --git a/server/Makefile-vllm b/server/Makefile-vllm index c9c1d520..ada484a6 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,10 +1,10 @@ vllm-cuda: # Clone vllm pip install -U ninja packaging --no-cache-dir - git clone https://github.com/vllm-project/vllm.git vllm + git clone https://github.com/OlivierDehaene/vllm.git vllm build-vllm-cuda: vllm-cuda - cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc + cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2 cd vllm && python setup.py build install-vllm-cuda: build-vllm-cuda diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py index 2e6ae086..4be8b1b9 100644 --- a/server/text_generation_server/models/cache_manager.py +++ b/server/text_generation_server/models/cache_manager.py @@ -43,7 +43,7 @@ class CacheManager: ] self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") self.slots = torch.arange( - 0, num_blocks * self.block_size, dtype=torch.int32 + 0, num_blocks * self.block_size, dtype=torch.int64 ).view(num_blocks, self.block_size) def allocate( diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 985bbd8e..56d9a966 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -23,10 +23,10 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -34,65 +34,106 @@ from text_generation_server.utils.layers import ( PositionRotaryEmbedding, SpeculativeHead, get_linear, - FastRMSNorm, + FastLayerNorm, ) +if IS_CUDA_SYSTEM: + import dropout_layer_norm +else: + dropout_layer_norm = None -class CohereConfig(PretrainedConfig): - def __init__( + +class CohereRotary(PositionRotaryEmbedding): + def forward( self, - vocab_size=256000, - hidden_size=8192, - intermediate_size=22528, - num_hidden_layers=40, - num_attention_heads=64, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=8192, - initializer_range=0.02, - layer_norm_eps=1e-5, - use_cache=True, - pad_token_id=0, - bos_token_id=5, - eos_token_id=255001, - pretraining_tp=1, - tie_word_embeddings=True, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - logit_scale=1.0, - **kwargs, + query: torch.Tensor, 
+ key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads + # Such controlflows may add some overhead. + if IS_CUDA_SYSTEM: + import rotary_emb - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads + q1 = query[..., ::2] + q2 = query[..., 1::2] - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.logit_scale = logit_scale + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, + k1 = key[..., ::2] + k2 = key[..., 1::2] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif IS_ROCM_SYSTEM: + from vllm import pos_encoding_ops + + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False) + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
+ ) + + +class CohereLayerNorm(nn.Module): + def __init__(self, prefix, weights, eps): + super().__init__() + weight = weights.get_sharded(f"{prefix}.weight", dim=0) + self.weight = nn.Parameter(weight) + # Fake weights + self.ones = weight.new_ones(weight.shape[1]) + self.eps = eps + + def forward(self, hidden_states): + if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: + hidden_states = hidden_states.reshape( + -1, self.weight.shape[0], self.weight.shape[1] + ) + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + hidden_states_minus_mean = hidden_states - mean + variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) + hidden_states = self.weight.to(torch.float32) * hidden_states + hidden_states = hidden_states.view(-1, self.weight.shape[1]) + return hidden_states.to(input_dtype) + + ( + hidden_states, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + None, + self.ones, + None, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, ) + # Required to apply one weight matrix per head + hidden_states = hidden_states.view( + -1, self.weight.shape[0], self.weight.shape[1] + ) + hidden_states = self.weight * hidden_states + hidden_states = hidden_states.view(-1, self.weight.shape[1]) + + return hidden_states + def load_attention(config, prefix, weights): if config.num_attention_heads != config.num_key_value_heads: @@ -154,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding.static( + self.rotary_emb = CohereRotary.static( config=config, dim=self.head_size, base=config.rope_theta, @@ -175,6 +216,22 @@ class FlashCohereAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + self.q_norm = CohereLayerNorm( + prefix=f"{prefix}.q_norm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.k_norm = CohereLayerNorm( + prefix=f"{prefix}.k_norm", + weights=weights, + eps=config.layer_norm_eps, + ) + else: + self.q_norm = None + self.k_norm = None + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", @@ -199,21 +256,28 @@ class FlashCohereAttention(torch.nn.Module): max_s, ): qkv = self.query_key_value(hidden_states) - query, kv = qkv.split( + query, key, value = qkv.split( [ self.head_size * self.num_heads, - 2 * self.head_size * self.num_key_value_heads, + self.head_size * self.num_key_value_heads, + self.head_size * self.num_key_value_heads, ], dim=1, ) + + if self.use_qk_norm: + query = query.reshape(-1, self.head_size) + key = key.reshape(-1, self.head_size) + query = self.q_norm(query.contiguous()) + key = self.k_norm(key.contiguous()) + query = query.view(-1, self.num_heads, self.head_size) - kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + key = key.view(-1, self.num_key_value_heads, self.head_size) + value = value.view(-1, self.num_key_value_heads, self.head_size) - self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + self.rotary_emb(query, key, cos, sin) - paged_attention.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots - ) + paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) # output tensor attn_output = 
torch.empty_like(query) @@ -223,8 +287,8 @@ class FlashCohereAttention(torch.nn.Module): # flash attention flash_attn.attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), + key, + value, attn_output, cu_seqlen_prefill, max_s, @@ -298,7 +362,7 @@ class FlashCohereLayer(nn.Module): ) self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.input_layernorm = FastRMSNorm.load( + self.input_layernorm = FastLayerNorm.load_no_bias( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps, @@ -362,7 +426,7 @@ class FlashCohereModel(torch.nn.Module): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = FastRMSNorm.load( + self.norm = FastLayerNorm.load_no_bias( prefix="model.norm", weights=weights, eps=config.layer_norm_eps ) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index dd0bcca5..d04ce39e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -16,14 +16,13 @@ import torch import torch.distributed -import numpy as np - from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any from loguru import logger +from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( FastLinear, @@ -37,14 +36,6 @@ from text_generation_server.utils.layers import ( ) from text_generation_server.utils.log import log_once -HAS_MEGABLOCKS = True -try: - import stk - import megablocks.ops as ops -except ImportError: - logger.warning("Dbrx: megablocks is not installed") - HAS_MEGABLOCKS = False - class DbrxAttentionConfig(PretrainedConfig): def __init__( @@ -531,18 +522,6 @@ def round_up(x: torch.Tensor, value: int): class BlockSparseMoE(nn.Module): - """ - Built on the paper and library Megablocks as described in - https://arxiv.org/abs/2211.15841. This implementation is - strictly equivalent to standard MoE with full capacity (no - dropped tokens). It's faster since it formulates MoE operations - in terms of block-sparse operations to accomodate imbalanced - assignments of tokens to experts, whereas standard MoE either - (1) drop tokens at the cost of reduced performance or (2) set - capacity factor to number of experts and thus waste computation - and memory on padding. 
- """ - def __init__(self, prefix, config: DbrxConfig, weights): super().__init__() self.moe_normalize_expert_weights = ( @@ -572,241 +551,40 @@ class BlockSparseMoE(nn.Module): ) # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim) - self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights) - self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights) - self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights) - - self.offsets = None - self.offsets_block_rows = 0 + w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view( + self.num_experts, self.ffn_dim, self.hidden_dim + ) + v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view( + self.num_experts, self.ffn_dim, self.hidden_dim + ) + self.wv1 = torch.cat([w1, v1], dim=1) + self.w2 = ( + _load_experts(config, f"{prefix}.experts.mlp.w2", weights) + .view(self.num_experts, self.ffn_dim, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) self.process_group = weights.process_group - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - self.blocking = 128 - self.quantize_scatter_num_bits = -1 - - def topology(self, x: torch.Tensor, padded_bins: torch.Tensor): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - assert self.ffn_dim % self.blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_dim // self.blocking - if self.offsets is None or block_rows > self.offsets_block_rows: - self.offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - self.offsets_block_rows = block_rows - offsets = self.offsets - else: - offsets = self.offsets[: block_rows + 1] - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, self.blocking, block_rows, blocks_per_row - ) - - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=x.dtype, - device="meta", - ) - shape = (padded_tokens, self.ffn_dim * self.num_experts) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - False, - False, - False, - ) - - def indices_and_padded_bins(self, selected_experts: torch.Tensor): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # selected_experts = selected_experts.int() - - # returns bin_ids == num of experts for this sequence ? == unique selected experts? - # and indices == how to sort tokens? - bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit) - # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k] - # indices => [14, 32, 33, ...] => [num_tokens * top_k] - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(selected_experts, self.num_experts) - # tokens_per_expert => [3, 0, 2, ...] => [num_experts] - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. 
- - # List of size num_experts - padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking) - # padded_tokens_per_expert => [128, O, 128, ...] - - # Cumulative selected experts per token - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - # padded_bins => [128, 128, 256, ...] - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - # bins => [3, 3, 5, ...] - - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward(self, x: torch.Tensor) -> torch.Tensor: - """ - x: (sequence_length, model_dim) - gate_logits: (sequence_length, n_experts) - """ - # optional reshape - input_shape = x.shape - x = x.view(-1, input_shape[-1]) - - # gate_logits: (sequence_length, n_experts) - gate_logits = self.gate(x) - selected_experts, weights = select_experts( - gate_logits, self.top_k, self.moe_normalize_expert_weights - ) - - ( - indices, - bin_ids, - bins, - padded_bins, - _, - ) = self.indices_and_padded_bins(selected_experts) - - # Permute tokens and pad to prepare expert computation - # (top_k * sequence_length + padding, model_dim) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k) - - # Create the sparse matrix topology - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation - # First Dense x Dense -> Sparse for w1 and v1, - # (top_k * sequence_length + padding, ffn_dim * n_experts) - x = stk.Matrix( - topo.size(), - self.act(stk.ops.sdd(x, self.w1.t(), topo).data) - * stk.ops.sdd(x, self.v1.t(), topo).data, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Then Sparse x Dense -> Dense for w2 - # (top_k * sequence_length + padding, model_dim) - x = stk.ops.dsd(x, self.w2) - - # Permute back and remove padding - # (sequence_length, model_dim) - x = ops.padded_scatter( + def forward(self, x: torch.Tensor) -> torch.Tensor: + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(x) + out = fused_moe( x, - indices, - bin_ids, - weights, - bins, - padded_bins, + self.wv1, + self.w2, + router_logits, self.top_k, - self.quantize_scatter_num_bits, - ).view(*input_shape) - - if self.process_group.size() > 1: - torch.distributed.all_reduce(x, group=self.process_group) - - return x.view(*input_shape) - - def dense_forward(self, x: torch.Tensor) -> torch.Tensor: - """ - x: (sequence_length, model_dim) - gate_logits: (sequence_length, n_experts) - """ - # optional reshape - input_shape = x.shape - x = x.view(-1, input_shape[-1]) - - # gate_logits: (sequence_length, n_experts) - gate_logits = self.gate(x) - # all_probs: (sequence_length, n_experts) and upcast for softmax - weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float) - - if self.top_k < self.num_experts: - _, not_selected_experts = torch.topk( - weights, - self.num_experts - self.top_k, - largest=False, - sorted=False, - dim=1, - ) - # Mask not selected experts - weights.scatter_(1, not_selected_experts, 0) - - # Re-normalize - if self.moe_normalize_expert_weights: - weights = weights / torch.norm( - weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True - ) - weights = weights.to(x.dtype) - - # Expand to [num_experts, sequence_length, model_dim] - x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1]) - - # Permute to [num_experts, model_dim, ffn_dim] 
- w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute( - 0, 2, 1 + renormalize=self.moe_normalize_expert_weights, + inplace=True, ) - v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute( - 0, 2, 1 - ) - - inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1) - - out = torch.bmm( - inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim) - ) - # Mask not selected experts - out *= weights.t().view(self.num_experts, -1, 1) - - # Sum experts - out = out.sum(0) # Reduce sum if self.process_group.size() > 1: torch.distributed.all_reduce(out, group=self.process_group) - return out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if len(x) > 256 and HAS_MEGABLOCKS: - return self.sparse_forward(x) - # This is faster when there is not a lot of tokens - return self.dense_forward(x) + return out.view(*x.shape) class DenseMoE(nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index d71a3f0c..89eb8f43 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -24,6 +24,7 @@ import torch.distributed import numpy as np from torch import nn +from vllm.model_executor.layers.fused_moe import fused_moe from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple @@ -41,14 +42,6 @@ from text_generation_server.utils.layers import ( get_linear, ) -HAS_MEGABLOCKS = True -try: - import stk - import megablocks.ops as ops -except ImportError: - logger.warning("Mixtral: megablocks is not installed") - HAS_MEGABLOCKS = False - class MixtralConfig(PretrainedConfig): model_type = "mixtral" @@ -321,18 +314,6 @@ def round_up(x: torch.Tensor, value: int): class BlockSparseMoE(nn.Module): - """ - Built on the paper and library Megablocks as described in - https://arxiv.org/abs/2211.15841. This implementation is - strictly equivalent to standard MoE with full capacity (no - dropped tokens). It's faster since it formulates MoE operations - in terms of block-sparse operations to accomodate imbalanced - assignments of tokens to experts, whereas standard MoE either - (1) drop tokens at the cost of reduced performance or (2) set - capacity factor to number of experts and thus waste computation - and memory on padding. 
- """ - def __init__(self, prefix, config: MixtralConfig, weights): super().__init__() self.hidden_dim = config.hidden_size @@ -357,236 +338,40 @@ class BlockSparseMoE(nn.Module): self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim) - self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights) - self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights) - self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights) - - self.offsets = None - self.offsets_block_rows = 0 + w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view( + self.num_experts, self.ffn_dim, self.hidden_dim + ) + w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view( + self.num_experts, self.ffn_dim, self.hidden_dim + ) + self.w13 = torch.cat([w1, w3], dim=1) + self.w2 = ( + _load_experts(config, f"{prefix}.experts", "w2", weights) + .view(self.num_experts, self.ffn_dim, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) self.process_group = weights.process_group - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - self.blocking = 128 - self.quantize_scatter_num_bits = -1 - - def topology(self, x: torch.Tensor, padded_bins: torch.Tensor): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - assert self.ffn_dim % self.blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_dim // self.blocking - if self.offsets is None or block_rows > self.offsets_block_rows: - self.offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - self.offsets_block_rows = block_rows - offsets = self.offsets - else: - offsets = self.offsets[: block_rows + 1] - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, self.blocking, block_rows, blocks_per_row - ) - - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=x.dtype, - device="meta", - ) - shape = (padded_tokens, self.ffn_dim * self.num_experts) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - False, - False, - False, - ) - - def indices_and_padded_bins(self, selected_experts: torch.Tensor): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # selected_experts = selected_experts.int() - - # returns bin_ids == num of experts for this sequence ? == unique selected experts? - # and indices == how to sort tokens? - bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit) - # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k] - # indices => [14, 32, 33, ...] => [num_tokens * top_k] - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(selected_experts, self.num_experts) - # tokens_per_expert => [3, 0, 2, ...] 
=> [num_experts] - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - - # List of size num_experts - padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking) - # padded_tokens_per_expert => [128, O, 128, ...] - - # Cumulative selected experts per token - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - # padded_bins => [128, 128, 256, ...] - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - # bins => [3, 3, 5, ...] - - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward(self, x: torch.Tensor) -> torch.Tensor: - """ - x: (sequence_length, model_dim) - gate_logits: (sequence_length, n_experts) - """ - # optional reshape - input_shape = x.shape - x = x.view(-1, input_shape[-1]) - - # gate_logits: (sequence_length, n_experts) - gate_logits = self.gate(x) - selected_experts, weights = select_experts(gate_logits, self.top_k) - - ( - indices, - bin_ids, - bins, - padded_bins, - _, - ) = self.indices_and_padded_bins(selected_experts) - - # Permute tokens and pad to prepare expert computation - # (top_k * sequence_length + padding, model_dim) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k) - - # Create the sparse matrix topology - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation - # First Dense x Dense -> Sparse for w1 and w3, - # (top_k * sequence_length + padding, ffn_dim * n_experts) - x = stk.Matrix( - topo.size(), - self.act(stk.ops.sdd(x, self.w1.t(), topo).data) - * stk.ops.sdd(x, self.w3.t(), topo).data, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Then Sparse x Dense -> Dense for w2 - # (top_k * sequence_length + padding, model_dim) - x = stk.ops.dsd(x, self.w2) - - # Permute back and remove padding - # (sequence_length, model_dim) - x = ops.padded_scatter( + def forward(self, x: torch.Tensor) -> torch.Tensor: + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(x) + out = fused_moe( x, - indices, - bin_ids, - weights, - bins, - padded_bins, + self.w13, + self.w2, + router_logits, self.top_k, - self.quantize_scatter_num_bits, - ).view(*input_shape) - - if self.process_group.size() > 1: - torch.distributed.all_reduce(x, group=self.process_group) - - return x.view(*input_shape) - - def dense_forward(self, x: torch.Tensor) -> torch.Tensor: - """ - x: (sequence_length, model_dim) - gate_logits: (sequence_length, n_experts) - """ - # optional reshape - input_shape = x.shape - x = x.view(-1, input_shape[-1]) - - # gate_logits: (sequence_length, n_experts) - gate_logits = self.gate(x) - # all_probs: (sequence_length, n_experts) and upcast for softmax - all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float) - - if self.top_k < self.num_experts: - _, not_selected_experts = torch.topk( - all_probs, - self.num_experts - self.top_k, - largest=False, - sorted=False, - dim=1, - ) - # Mask not selected experts - all_probs.scatter_(1, not_selected_experts, 0) - - # Re-normalize - weights = all_probs / all_probs.sum(dim=1, keepdim=True) - weights = weights.to(x.dtype) - - # Expand to [num_experts, sequence_length, model_dim] - x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1]) - - # Permute to 
[num_experts, model_dim, ffn_dim] - w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute( - 0, 2, 1 + renormalize=True, + inplace=True, ) - w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute( - 0, 2, 1 - ) - - inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3) - - out = torch.bmm( - inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim) - ) - # Mask not selected experts - out *= weights.t().view(self.num_experts, -1, 1) - - # Sum experts - out = out.sum(0) # Reduce sum if self.process_group.size() > 1: torch.distributed.all_reduce(out, group=self.process_group) - return out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if len(x) > 256 and HAS_MEGABLOCKS: - return self.sparse_forward(x) - # This is faster when there is not a lot of tokens - return self.dense_forward(x) + return out.view(*x.shape) class DenseMoE(nn.Module): diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index be513511..2a9d3914 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -169,6 +169,11 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping[r.id] = i tokenized_input = tokenized_input[-r.truncate :] + if ( + tokenized_input[0] == tokenizer.bos_token_id + and tokenized_input[1] == tokenizer.bos_token_id + ): + tokenized_input = tokenized_input[1:] input_length = len(tokenized_input) input_lengths.append(input_length) @@ -694,7 +699,7 @@ class FlashCausalLM(Model): def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int32, device=self.device) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s block_tables = ( torch.arange(max_bt, dtype=torch.int32, device=self.device) diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py index 181a93b1..f85c7722 100644 --- a/server/text_generation_server/models/flash_cohere.py +++ b/server/text_generation_server/models/flash_cohere.py @@ -3,12 +3,11 @@ import torch.distributed from opentelemetry import trace from typing import Optional -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoConfig from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( FlashCohereForCausalLM, - CohereConfig, ) from text_generation_server.utils import ( initialize_torch_distributed, @@ -32,7 +31,7 @@ class FlashCohere(FlashCausalLM): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashCohere is only available on GPU") @@ -46,7 +45,7 @@ class FlashCohere(FlashCausalLM): from_slow=False, ) - config = CohereConfig.from_pretrained( + config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 
575dbba0..ace7ea8e 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -385,7 +385,7 @@ class BaseFlashMistral(FlashCausalLM): def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) - slots = torch.arange(bs, dtype=torch.int32, device=self.device) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s block_tables = ( torch.arange(max_bt, dtype=torch.int32, device=self.device) diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 48f8ef70..45090c64 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -88,6 +88,9 @@ def attention( out, cu_seqlens, cu_seqlens, + None, + None, + None, max_s, max_s, 0.0, diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 209f1c8a..f29e55c5 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -19,7 +19,6 @@ from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM -from text_generation_server.utils.log import log_once HAS_AWQ = True try: @@ -35,12 +34,6 @@ except Exception: HAS_EXLLAMA = False CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" -# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: -# V2 = False -# log_once( -# logger.warning, -# "Disabling exllama v2 and using v1 instead because there are issues when sharding", -# ) if os.getenv("DISABLE_EXLLAMA") == "True": HAS_EXLLAMA = False @@ -174,6 +167,8 @@ class EETQLinear(nn.Module): ) -> None: super().__init__() device = weight.device + if weight.dtype != torch.float16: + weight = weight.to(dtype=torch.float16) weight = torch.t(weight).contiguous().cpu() weight, scale = quant_weights(weight, torch.int8, False) diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 4b12744c..18e605b0 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -1,8 +1,7 @@ import torch # vllm imports -from vllm import cache_ops -from vllm import attention_ops +from vllm._C import cache_ops, ops _PARTITION_SIZE = 512 @@ -14,7 +13,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots) + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) def attention( @@ -54,9 +53,9 @@ def attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 + use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: - attention_ops.paged_attention_v1( + ops.paged_attention_v1( out, query, key_cache, @@ -68,6 +67,8 @@ def attention( block_size, max_s, None, + "auto", + 1.0, ) else: # Run PagedAttention V2. 
@@ -83,7 +84,7 @@ def attention( device=out.device, ) max_logits = torch.empty_like(exp_sums) - attention_ops.paged_attention_v2( + ops.paged_attention_v2( out, exp_sums, max_logits, @@ -98,4 +99,6 @@ def attention( block_size, max_s, None, + "auto", + 1.0, )
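Reviewer note on the MoE rewrite: flash_mixtral_modeling.py and flash_dbrx_modeling.py drop both the megablocks sparse path and the eager dense fallback and route every batch through vLLM's fused_moe kernel. The expert weights are now pre-merged at load time: the gate/up projections (w1/w3 for Mixtral, w1/v1 for DBRX) are concatenated into a single (num_experts, 2 * ffn_dim, hidden_dim) tensor, and the down projection is stored as (num_experts, hidden_dim, ffn_dim). The snippet below is only a plain PyTorch reference of the computation the fused kernel performs, not the kernel itself; the function name moe_reference is illustrative and the SiLU gating is assumed from the Mixtral/DBRX configs. It is here mainly to make the new weight layout easy to check:

    import torch

    def moe_reference(x, w13, w2, router_logits, top_k, renormalize=True):
        # Plain PyTorch reference for the routed, gated expert MLP that
        # fused_moe computes in a single fused kernel (illustration only,
        # SiLU gating assumed as in the Mixtral/DBRX configs).
        #   x:             (num_tokens, hidden_dim)
        #   w13:           (num_experts, 2 * ffn_dim, hidden_dim)  merged gate/up
        #   w2:            (num_experts, hidden_dim, ffn_dim)      down projection
        #   router_logits: (num_tokens, num_experts)
        num_experts, two_ffn, _ = w13.shape
        ffn_dim = two_ffn // 2

        # Softmax over experts, keep the top_k experts per token.
        probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        weights, selected = probs.topk(top_k, dim=-1)
        if renormalize:
            weights = weights / weights.sum(dim=-1, keepdim=True)
        weights = weights.to(x.dtype)

        out = torch.zeros_like(x)
        for expert in range(num_experts):
            token_idx, slot = torch.where(selected == expert)
            if token_idx.numel() == 0:
                continue
            xs = x[token_idx]
            gate, up = (xs @ w13[expert].t()).split(ffn_dim, dim=-1)
            expert_out = (torch.nn.functional.silu(gate) * up) @ w2[expert].t()
            out.index_add_(0, token_idx, weights[token_idx, slot].unsqueeze(1) * expert_out)
        return out

With that layout the call sites shrink to fused_moe(x, self.w13, self.w2, router_logits, self.top_k, renormalize=..., inplace=True), followed by an all_reduce when the expert weights are sharded across ranks.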
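On the Cohere side, c4ai-command-r-plus sets use_qk_norm, so the local CohereConfig is dropped in favor of AutoConfig, the attention splits the fused QKV output into separate query/key/value tensors, and each query and key head is optionally layer-normalized before the rotary embedding. CohereLayerNorm holds one weight vector per head (shape (num_heads, head_size), sharded on dim 0) and applies a bias-free LayerNorm rather than an RMSNorm, which is also why the decoder and final norms move from FastRMSNorm to FastLayerNorm.load_no_bias. The eager branch of the new class (taken on ROCm, or when the last dimension exceeds 8192) boils down to the sketch below; cohere_qk_norm is an illustrative helper name, not part of the patch:

    import torch

    def cohere_qk_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Per-head LayerNorm with a learned scale and no bias (eager branch).
        #   x:      (num_tokens * num_heads, head_size) flattened query or key heads
        #   weight: (num_heads, head_size), one scale vector per head
        num_heads, head_size = weight.shape
        dtype = x.dtype

        x = x.reshape(-1, num_heads, head_size).to(torch.float32)
        mean = x.mean(-1, keepdim=True)
        variance = (x - mean).pow(2).mean(-1, keepdim=True)
        x = (x - mean) * torch.rsqrt(variance + eps)

        # One weight row per head, then flatten back for the rotary/attention code.
        x = weight.to(torch.float32) * x
        return x.reshape(-1, head_size).to(dtype)

The CUDA fast path feeds the same tensors through dropout_layer_norm with a vector of ones as the gamma and applies the real per-head weights afterwards, so both branches should agree up to the usual float32 accumulation differences.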
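Last, paged_attention.py now imports cache_ops and ops from vllm._C and forwards the two extra trailing arguments the updated kernels expect (presumably the kv-cache dtype, "auto", and a kv scale of 1.0; the vllm fork pinned in Makefile-vllm is the reference for the exact signature). Slot indices move to int64 accordingly, which is why cache_manager.py and the two cuda_graph_warmup helpers change dtype as well. The V1/V2 dispatch also gains a context-length guard; with _PARTITION_SIZE = 512 as in the unchanged part of the file, the decision is roughly (use_paged_attention_v1 is an illustrative helper, not a function in the patch):

    _PARTITION_SIZE = 512  # same constant defined at the top of paged_attention.py

    def use_paged_attention_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
        # Mirrors the patched dispatch in attention(): V1 has no reduction step,
        # so keep it for short contexts when a single partition suffices or when
        # there is already enough num_seqs * num_heads parallelism.
        max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

In other words, V1 is kept for contexts up to 8192 tokens whenever a single partition suffices or num_seqs * num_heads already provides enough parallelism, and V2 with its reduction step handles everything else.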