fix: fix CohereForAI/c4ai-command-r-plus (#1707)

@Narsil @drbh this will update flash attention v2 and vllm.
You will need to re-install them.
OlivierDehaene 2024-04-10 17:20:25 +02:00 committed by GitHub
parent 4634b00c2a
commit ad9d6288c8
17 changed files with 281 additions and 620 deletions

View File

@ -85,7 +85,7 @@ FROM pytorch-install as kernel-builder
ARG MAX_JOBS=8
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ninja-build \
ninja-build cmake \
&& rm -rf /var/lib/apt/lists/*
# Build Flash Attention CUDA kernels
@ -160,11 +160,6 @@ WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all
# Build megablocks
FROM kernel-builder as megablocks-builder
RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
@ -186,8 +181,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy conda with PyTorch and Megablocks installed
COPY --from=megablocks-builder /opt/conda /opt/conda
# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@ -215,7 +210,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
# Install flash-attention dependencies
# Install vllm/flash-attention dependencies
RUN pip install einops --no-cache-dir
# Install server

View File

@ -499,6 +499,9 @@ fn shard_manager(
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Remove LOG_LEVEL if present
envs.retain(|(name, _)| name != "LOG_LEVEL");
// Torch Distributed Env vars
envs.push(("RANK".into(), rank.to_string().into()));
envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@ -586,6 +589,7 @@ fn shard_manager(
tracing::info!("Starting shard");
let mut p = match Command::new("text-generation-server")
.args(shard_args)
.env_clear()
.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
@ -824,6 +828,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Remove LOG_LEVEL if present
envs.retain(|(name, _)| name != "LOG_LEVEL");
// Disable progress bar
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
@ -858,6 +865,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
tracing::info!("Starting download process.");
let mut download_process = match Command::new("text-generation-server")
.args(download_args)
.env_clear()
.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
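
Both spawn sites now pair env_clear() with an explicitly rebuilt environment map, so removing LOG_LEVEL from the copied variables actually takes effect in the child process. A minimal Python sketch of the same pattern (standard library only, not the launcher's code):

import os
import subprocess
import sys

# Rebuild the environment explicitly instead of inheriting it wholesale:
# copy the current variables, drop LOG_LEVEL, add what the child needs.
envs = {k: v for k, v in os.environ.items() if k != "LOG_LEVEL"}
envs["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

# Passing env= replaces the inherited environment entirely, which is the
# subprocess equivalent of .env_clear().envs(envs) in the Rust launcher.
subprocess.run(
    [sys.executable, "-c", "import os; print('LOG_LEVEL' in os.environ)"],
    env=envs,
    check=True,
)
# prints: False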

View File

@ -1,8 +1,8 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::{
ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
Message, PrefillToken, Queue, Token,
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
@ -86,7 +86,18 @@ impl Infer {
let chat_template = tokenizer_config
.chat_template
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
.and_then(|t| match t {
ChatTemplateVersions::Single(template) => Some(template),
ChatTemplateVersions::Multiple(templates) => templates
.into_iter()
.find(|t| t.name == "default")
.map(|t| t.template),
})
.map(|t| {
// .strip() is not supported in minijinja
let t = t.replace(".strip()", " | trim");
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
});
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
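
The .strip() rewrite is needed because Hub chat templates are written for Jinja2, which allows calling Python string methods, while minijinja only offers filters. A small illustration of the rewrite, using Python's jinja2 as a stand-in for minijinja (the template string is made up):

from jinja2 import Environment  # jinja2 here is only a stand-in for minijinja

# Templates on the Hub sometimes call the Python method .strip(); minijinja has
# no such method, so the router rewrites it to the equivalent trim filter.
template_src = "{{ bos_token + messages[0]['content'].strip() }}"
template_src = template_src.replace(".strip()", " | trim")

rendered = Environment().from_string(template_src).render(
    bos_token="<s>", messages=[{"content": "  Hello  "}]
)
print(rendered)  # "<s>Hello"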
@ -1099,7 +1110,7 @@ mod tests {
ChatTemplateTestItem {
name: "_base",
chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
input: ChatTemplateInputs{
input: ChatTemplateInputs {
messages: example_chat.clone(),
add_generation_prompt: false,
bos_token: Some(""),
@ -1110,7 +1121,7 @@ mod tests {
ChatTemplateTestItem {
name: "blenderbot",
chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
input: ChatTemplateInputs{
input: ChatTemplateInputs {
messages: example_chat.clone(),
add_generation_prompt: false,
bos_token: Some(""),
@ -1121,7 +1132,7 @@ mod tests {
ChatTemplateTestItem {
name: "blenderbot_small",
chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
input: ChatTemplateInputs{
input: ChatTemplateInputs {
messages: example_chat.clone(),
add_generation_prompt: false,
bos_token: Some(""),
@ -1132,7 +1143,7 @@ mod tests {
ChatTemplateTestItem {
name: "bloom",
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
input: ChatTemplateInputs{
input: ChatTemplateInputs {
messages: example_chat.clone(),
add_generation_prompt: false,
bos_token: Some(""),
@ -1143,7 +1154,7 @@ mod tests {
ChatTemplateTestItem {
name: "gpt_neox",
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
input: ChatTemplateInputs{
input: ChatTemplateInputs {
messages: example_chat.clone(),
add_generation_prompt: false,
bos_token: Some(""),
@ -1154,38 +1165,37 @@ mod tests {
ChatTemplateTestItem {
name: "gpt2",
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
input: ChatTemplateInputs{
messages: example_chat.clone(),
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("<|endoftext|>"),
input: ChatTemplateInputs {
messages: example_chat.clone(),
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("<|endoftext|>"),
},
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
},
ChatTemplateTestItem {
name: "llama",
// NOTE: the `.strip()` has been replaced with `| trim` in the following template
chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
input: ChatTemplateInputs{
messages: example_chat_with_system.clone(),
add_generation_prompt: true,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
input: ChatTemplateInputs {
messages: example_chat_with_system.clone(),
add_generation_prompt: true,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
},
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
},
ChatTemplateTestItem {
name: "whisper",
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
input: ChatTemplateInputs{
messages: example_chat.clone(),
add_generation_prompt: true,
bos_token: Some(""),
eos_token: Some("<|endoftext|>"),
input: ChatTemplateInputs {
messages: example_chat.clone(),
add_generation_prompt: true,
bos_token: Some(""),
eos_token: Some("<|endoftext|>"),
},
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
}
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
},
];
#[allow(unused_variables)] // name is unused
@ -1211,7 +1221,7 @@ mod tests {
messages: example_chat_with_system.clone(),
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("</s>")
eos_token: Some("</s>"),
},
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
},
@ -1237,7 +1247,7 @@ mod tests {
bos_token: Some(""),
eos_token: Some("</s>"),
},
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>"
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
},
ChatTemplateTestItem {
name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
@ -1259,7 +1269,7 @@ mod tests {
bos_token: Some("<s>"),
eos_token: Some("</s>"),
},
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
},
ChatTemplateTestItem {
name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
@ -1276,7 +1286,7 @@ mod tests {
name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
input: ChatTemplateInputs {
messages: example_chat.clone(),
messages: example_chat.clone(),
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
@ -1360,7 +1370,7 @@ mod tests {
bos_token: Some("<s>"),
eos_token: Some("</s>"),
},
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>"
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
},
ChatTemplateTestItem {
name: "internlm/internlm2-chat-7b",
@ -1443,7 +1453,7 @@ mod tests {
eos_token: Some("</s>"),
},
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
}
},
];
#[allow(unused_variables)] // name is unused

View File

@ -49,9 +49,22 @@ pub struct HubModelInfo {
pub pipeline_tag: Option<String>,
}
#[derive(Clone, Deserialize, Default)]
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct ChatTemplate {
name: String,
template: String,
}
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ChatTemplateVersions {
Single(String),
Multiple(Vec<ChatTemplate>),
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubTokenizerConfig {
pub chat_template: Option<String>,
pub chat_template: Option<ChatTemplateVersions>,
pub completion_template: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")]
pub bos_token: Option<String>,
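
The untagged ChatTemplateVersions enum accepts either shape of chat_template found in tokenizer_config.json files. A rough Python illustration of the two payloads and of the "default" selection (template names below are examples, not taken from a specific model):

import json

# Shape 1: a single template string.
single = json.loads('{"chat_template": "{{ bos_token }}"}')

# Shape 2: a list of named templates; the router keeps the one named "default".
multiple = json.loads(
    '{"chat_template": ['
    '{"name": "default", "template": "{{ bos_token }}"}, '
    '{"name": "tool_use", "template": "{{ eos_token }}"}'
    ']}'
)

def pick_template(chat_template):
    # Mirrors ChatTemplateVersions: Single(String) is used as-is,
    # Multiple(Vec<ChatTemplate>) is searched for the "default" entry.
    if isinstance(chat_template, str):
        return chat_template
    return next((t["template"] for t in chat_template if t["name"] == "default"), None)

print(pick_template(single["chat_template"]))
print(pick_template(multiple["chat_template"]))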
@ -978,7 +991,10 @@ mod tests {
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
// check that we successfully parsed the tokens
assert_eq!(config.chat_template, Some("test".to_string()));
assert_eq!(
config.chat_template,
Some(ChatTemplateVersions::Single("test".to_string()))
);
assert_eq!(
config.bos_token,
Some("<begin▁of▁sentence>".to_string())
@ -1010,7 +1026,10 @@ mod tests {
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
// check that we successfully parsed the tokens
assert_eq!(config.chat_template, Some("test".to_string()));
assert_eq!(
config.chat_template,
Some(ChatTemplateVersions::Single("test".to_string()))
);
assert_eq!(
config.bos_token,
Some("<begin▁of▁sentence>".to_string())

View File

@ -17,9 +17,6 @@ gen-server:
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
install-megablocks:
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
install: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt

View File

@ -1,4 +1,4 @@
flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69

View File

@ -1,10 +1,10 @@
vllm-cuda:
# Clone vllm
pip install -U ninja packaging --no-cache-dir
git clone https://github.com/vllm-project/vllm.git vllm
git clone https://github.com/OlivierDehaene/vllm.git vllm
build-vllm-cuda: vllm-cuda
cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2
cd vllm && python setup.py build
install-vllm-cuda: build-vllm-cuda

View File

@ -43,7 +43,7 @@ class CacheManager:
]
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
self.slots = torch.arange(
0, num_blocks * self.block_size, dtype=torch.int32
0, num_blocks * self.block_size, dtype=torch.int64
).view(num_blocks, self.block_size)
def allocate(

View File

@ -23,10 +23,10 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -34,65 +34,106 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
SpeculativeHead,
get_linear,
FastRMSNorm,
FastLayerNorm,
)
if IS_CUDA_SYSTEM:
import dropout_layer_norm
else:
dropout_layer_norm = None
class CohereConfig(PretrainedConfig):
def __init__(
class CohereRotary(PositionRotaryEmbedding):
def forward(
self,
vocab_size=256000,
hidden_size=8192,
intermediate_size=22528,
num_hidden_layers=40,
num_attention_heads=64,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=8192,
initializer_range=0.02,
layer_norm_eps=1e-5,
use_cache=True,
pad_token_id=0,
bos_token_id=5,
eos_token_id=255001,
pretraining_tp=1,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
logit_scale=1.0,
**kwargs,
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# Such control flow may add some overhead.
if IS_CUDA_SYSTEM:
import rotary_emb
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
q1 = query[..., ::2]
q2 = query[..., 1::2]
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.logit_scale = logit_scale
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
k1 = key[..., ::2]
k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops
# NOTE: On ROCm systems, we use a ROPE implementation adapted from vLLM which launches a single kernel for both query/key, contrary to the flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on ROCm, it appears hipcc is unable to unroll loops, resulting in even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
else:
raise ValueError(
"Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
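
Both kernel paths above rotate interleaved feature pairs in place (the ::2 / 1::2 slices). A minimal eager-mode reference of that math, with illustrative shapes:

import torch

def rotary_interleaved_reference(x, cos, sin):
    # Eager reference for the in-place kernels above: every (even, odd) feature
    # pair is rotated by the position-dependent angle encoded in cos/sin.
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Illustrative shapes: (tokens, heads, head_size), cos/sin broadcast over heads.
tokens, heads, head_size = 3, 2, 8
query = torch.randn(tokens, heads, head_size)
angles = torch.rand(tokens, 1, head_size // 2)
rotated = rotary_interleaved_reference(query, angles.cos(), angles.sin())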
class CohereLayerNorm(nn.Module):
def __init__(self, prefix, weights, eps):
super().__init__()
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
self.weight = nn.Parameter(weight)
# Fake weights
self.ones = weight.new_ones(weight.shape[1])
self.eps = eps
def forward(self, hidden_states):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
hidden_states = hidden_states.reshape(
-1, self.weight.shape[0], self.weight.shape[1]
)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
hidden_states_minus_mean = hidden_states - mean
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
hidden_states = self.weight.to(torch.float32) * hidden_states
hidden_states = hidden_states.view(-1, self.weight.shape[1])
return hidden_states.to(input_dtype)
(
hidden_states,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
self.ones,
None,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
# Required to apply one weight matrix per head
hidden_states = hidden_states.view(
-1, self.weight.shape[0], self.weight.shape[1]
)
hidden_states = self.weight * hidden_states
hidden_states = hidden_states.view(-1, self.weight.shape[1])
return hidden_states
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
@ -154,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
self.rotary_emb = CohereRotary.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
@ -175,6 +216,22 @@ class FlashCohereAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm:
self.q_norm = CohereLayerNorm(
prefix=f"{prefix}.q_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.k_norm = CohereLayerNorm(
prefix=f"{prefix}.k_norm",
weights=weights,
eps=config.layer_norm_eps,
)
else:
self.q_norm = None
self.k_norm = None
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
@ -199,21 +256,28 @@ class FlashCohereAttention(torch.nn.Module):
max_s,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
query, key, value = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
self.head_size * self.num_key_value_heads,
self.head_size * self.num_key_value_heads,
],
dim=1,
)
if self.use_qk_norm:
query = query.reshape(-1, self.head_size)
key = key.reshape(-1, self.head_size)
query = self.q_norm(query.contiguous())
key = self.k_norm(key.contiguous())
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
key = key.view(-1, self.num_key_value_heads, self.head_size)
value = value.view(-1, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
self.rotary_emb(query, key, cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -223,8 +287,8 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
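
Splitting the fused projection into separate query/key/value tensors is what makes the per-head q/k layer norms possible before the rotary embedding and the KV-cache write. A shape-only sketch of the split and reshapes (sizes are made up, not the real command-r-plus configuration):

import torch

# Hypothetical sizes, for illustration only.
num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 8, 64

qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_size)

# Same split as above: query, then key, then value, each with its own head count.
query, key, value = qkv.split(
    [num_heads * head_size, num_kv_heads * head_size, num_kv_heads * head_size],
    dim=1,
)

# q/k-norm operates on one head at a time, so flatten to (tokens * heads, head_size)
# before normalizing, then restore (tokens, heads, head_size) for attention.
query = query.reshape(-1, head_size)          # -> (num_tokens * num_heads, head_size)
query = query.view(-1, num_heads, head_size)  # -> (num_tokens, num_heads, head_size)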
@ -298,7 +362,7 @@ class FlashCohereLayer(nn.Module):
)
self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(
self.input_layernorm = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
@ -362,7 +426,7 @@ class FlashCohereModel(torch.nn.Module):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
self.norm = FastLayerNorm.load_no_bias(
prefix="model.norm", weights=weights, eps=config.layer_norm_eps
)
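
Switching from FastRMSNorm to FastLayerNorm.load_no_bias is part of the actual fix: the Cohere layers expect a bias-free LayerNorm, which subtracts the per-feature mean before rescaling, whereas RMSNorm does not, so the two yield different activations. A quick reference comparison in plain PyTorch:

import torch

def rms_norm(x, weight, eps=1e-5):
    # RMSNorm: rescale by the root mean square only; no mean subtraction, no bias.
    return weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def layer_norm_no_bias(x, weight, eps=1e-5):
    # Bias-free LayerNorm: subtract the mean, then rescale by the variance.
    mean = x.mean(-1, keepdim=True)
    var = (x - mean).pow(2).mean(-1, keepdim=True)
    return weight * (x - mean) * torch.rsqrt(var + eps)

x = torch.randn(2, 8) + 3.0  # a non-zero mean makes the difference obvious
w = torch.ones(8)
print((rms_norm(x, w) - layer_norm_no_bias(x, w)).abs().max())  # clearly non-zero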

View File

@ -16,14 +16,13 @@
import torch
import torch.distributed
import numpy as np
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
FastLinear,
@ -37,14 +36,6 @@ from text_generation_server.utils.layers import (
)
from text_generation_server.utils.log import log_once
HAS_MEGABLOCKS = True
try:
import stk
import megablocks.ops as ops
except ImportError:
logger.warning("Dbrx: megablocks is not installed")
HAS_MEGABLOCKS = False
class DbrxAttentionConfig(PretrainedConfig):
def __init__(
@ -531,18 +522,6 @@ def round_up(x: torch.Tensor, value: int):
class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accommodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""
def __init__(self, prefix, config: DbrxConfig, weights):
super().__init__()
self.moe_normalize_expert_weights = (
@ -572,241 +551,40 @@ class BlockSparseMoE(nn.Module):
)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights)
self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights)
self.offsets = None
self.offsets_block_rows = 0
w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
self.wv1 = torch.cat([w1, v1], dim=1)
self.w2 = (
_load_experts(config, f"{prefix}.experts.mlp.w2", weights)
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)
self.process_group = weights.process_group
# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[: block_rows + 1]
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(
padded_bins, self.blocking, block_rows, blocks_per_row
)
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
False,
False,
False,
)
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
# selected_experts = selected_experts.int()
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
# Round the token counts up to the block size used in
# the matrix multiplications. Calculate the starting
# position of each bin.
# List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...]
# Cumulative selected experts per token
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# padded_bins => [128, 128, 256, ...]
# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
# bins => [3, 3, 5, ...]
return indices, bin_ids, bins, padded_bins, tokens_per_expert
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
selected_experts, weights = select_experts(
gate_logits, self.top_k, self.moe_normalize_expert_weights
)
(
indices,
bin_ids,
bins,
padded_bins,
_,
) = self.indices_and_padded_bins(selected_experts)
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)
# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and v1,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
* stk.ops.sdd(x, self.v1.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)
# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)
# Permute back and remove padding
# (sequence_length, model_dim)
x = ops.padded_scatter(
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = fused_moe(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
self.wv1,
self.w2,
router_logits,
self.top_k,
self.quantize_scatter_num_bits,
).view(*input_shape)
if self.process_group.size() > 1:
torch.distributed.all_reduce(x, group=self.process_group)
return x.view(*input_shape)
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
weights,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
weights.scatter_(1, not_selected_experts, 0)
# Re-normalize
if self.moe_normalize_expert_weights:
weights = weights / torch.norm(
weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
)
weights = weights.to(x.dtype)
# Expand to [num_experts, sequence_length, model_dim]
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
# Permute to [num_experts, model_dim, ffn_dim]
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1)
out = torch.bmm(
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
)
# Mask not selected experts
out *= weights.t().view(self.num_experts, -1, 1)
# Sum experts
out = out.sum(0)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x) > 256 and HAS_MEGABLOCKS:
return self.sparse_forward(x)
# This is faster when there is not a lot of tokens
return self.dense_forward(x)
return out.view(*x.shape)
class DenseMoE(nn.Module):
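
Both megablocks code paths (sparse_forward and dense_forward) collapse into one call to vLLM's fused_moe, with w1/v1 concatenated per expert and w2 pre-transposed. The following slow eager sketch shows what that call computes; it assumes a SiLU activation and skips the renormalization controlled by the renormalize flag:

import torch
import torch.nn.functional as F

def moe_reference(x, wv1, w2, router_logits, top_k):
    # x:  (tokens, hidden)          wv1: (experts, 2 * ffn, hidden)  -- [w1; v1]
    # w2: (experts, hidden, ffn)    router_logits: (tokens, experts)
    probs = router_logits.softmax(dim=-1, dtype=torch.float)
    weights, experts = probs.topk(top_k, dim=-1)
    ffn = wv1.shape[1] // 2
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for w, e in zip(weights[t], experts[t]):
            gate = x[t] @ wv1[e, :ffn].T   # w1 projection (gated through SiLU)
            up = x[t] @ wv1[e, ffn:].T     # v1 projection
            out[t] += w.to(x.dtype) * ((F.silu(gate) * up) @ w2[e].T)
    return out

# Tiny smoke test with made-up sizes.
tokens, hidden, ffn, experts, top_k = 3, 16, 32, 4, 2
out = moe_reference(
    torch.randn(tokens, hidden),
    torch.randn(experts, 2 * ffn, hidden),
    torch.randn(experts, hidden, ffn),
    torch.randn(tokens, experts),
    top_k,
)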

View File

@ -24,6 +24,7 @@ import torch.distributed
import numpy as np
from torch import nn
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
@ -41,14 +42,6 @@ from text_generation_server.utils.layers import (
get_linear,
)
HAS_MEGABLOCKS = True
try:
import stk
import megablocks.ops as ops
except ImportError:
logger.warning("Mixtral: megablocks is not installed")
HAS_MEGABLOCKS = False
class MixtralConfig(PretrainedConfig):
model_type = "mixtral"
@ -321,18 +314,6 @@ def round_up(x: torch.Tensor, value: int):
class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accommodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""
def __init__(self, prefix, config: MixtralConfig, weights):
super().__init__()
self.hidden_dim = config.hidden_size
@ -357,236 +338,40 @@ class BlockSparseMoE(nn.Module):
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
self.offsets = None
self.offsets_block_rows = 0
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
self.w13 = torch.cat([w1, w3], dim=1)
self.w2 = (
_load_experts(config, f"{prefix}.experts", "w2", weights)
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)
self.process_group = weights.process_group
# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[: block_rows + 1]
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(
padded_bins, self.blocking, block_rows, blocks_per_row
)
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
False,
False,
False,
)
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
# selected_experts = selected_experts.int()
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
# Round the token counts up to the block size used in
# the matrix multiplications. Calculate the starting
# position of each bin.
# List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...]
# Cumulative selected experts per token
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# padded_bins => [128, 128, 256, ...]
# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
# bins => [3, 3, 5, ...]
return indices, bin_ids, bins, padded_bins, tokens_per_expert
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
selected_experts, weights = select_experts(gate_logits, self.top_k)
(
indices,
bin_ids,
bins,
padded_bins,
_,
) = self.indices_and_padded_bins(selected_experts)
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)
# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and w3,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
* stk.ops.sdd(x, self.w3.t(), topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)
# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)
# Permute back and remove padding
# (sequence_length, model_dim)
x = ops.padded_scatter(
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = fused_moe(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
self.w13,
self.w2,
router_logits,
self.top_k,
self.quantize_scatter_num_bits,
).view(*input_shape)
if self.process_group.size() > 1:
torch.distributed.all_reduce(x, group=self.process_group)
return x.view(*input_shape)
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
all_probs,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
all_probs.scatter_(1, not_selected_experts, 0)
# Re-normalize
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
weights = weights.to(x.dtype)
# Expand to [num_experts, sequence_length, model_dim]
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
# Permute to [num_experts, model_dim, ffn_dim]
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
renormalize=True,
inplace=True,
)
w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
0, 2, 1
)
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
out = torch.bmm(
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
)
# Mask not selected experts
out *= weights.t().view(self.num_experts, -1, 1)
# Sum experts
out = out.sum(0)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x) > 256 and HAS_MEGABLOCKS:
return self.sparse_forward(x)
# This is faster when there is not a lot of tokens
return self.dense_forward(x)
return out.view(*x.shape)
class DenseMoE(nn.Module):

View File

@ -169,6 +169,11 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
if (
tokenized_input[0] == tokenizer.bos_token_id
and tokenized_input[1] == tokenizer.bos_token_id
):
tokenized_input = tokenized_input[1:]
input_length = len(tokenized_input)
input_lengths.append(input_length)
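
The added check drops a duplicated leading BOS token, which can occur when the prompt text already contains the BOS string (for example via a chat template) and the tokenizer prepends another one. A toy illustration with made-up token ids:

# Hypothetical ids for illustration only.
bos_token_id = 5
tokenized_input = [5, 5, 9234, 881]  # BOS, BOS, "Hello", "!"

# Same guard as above: collapse a duplicated leading BOS into a single one.
if (
    tokenized_input[0] == bos_token_id
    and tokenized_input[1] == bos_token_id
):
    tokenized_input = tokenized_input[1:]

assert tokenized_input == [5, 9234, 881]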
@ -694,7 +699,7 @@ class FlashCausalLM(Model):
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)

View File

@ -3,12 +3,11 @@ import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoConfig
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
CohereConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
@ -32,7 +31,7 @@ class FlashCohere(FlashCausalLM):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashCohere is only available on GPU")
@ -46,7 +45,7 @@ class FlashCohere(FlashCausalLM):
from_slow=False,
)
config = CohereConfig.from_pretrained(
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize

View File

@ -385,7 +385,7 @@ class BaseFlashMistral(FlashCausalLM):
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)

View File

@ -88,6 +88,9 @@ def attention(
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,

View File

@ -19,7 +19,6 @@ from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.log import log_once
HAS_AWQ = True
try:
@ -35,12 +34,6 @@ except Exception:
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
# V2 = False
# log_once(
# logger.warning,
# "Disabling exllama v2 and using v1 instead because there are issues when sharding",
# )
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
@ -174,6 +167,8 @@ class EETQLinear(nn.Module):
) -> None:
super().__init__()
device = weight.device
if weight.dtype != torch.float16:
weight = weight.to(dtype=torch.float16)
weight = torch.t(weight).contiguous().cpu()
weight, scale = quant_weights(weight, torch.int8, False)

View File

@ -1,8 +1,7 @@
import torch
# vllm imports
from vllm import cache_ops
from vllm import attention_ops
from vllm._C import cache_ops, ops
_PARTITION_SIZE = 512
@ -14,7 +13,7 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def attention(
@ -54,9 +53,9 @@ def attention(
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
attention_ops.paged_attention_v1(
ops.paged_attention_v1(
out,
query,
key_cache,
@ -68,6 +67,8 @@ def attention(
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
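
The V1/V2 dispatch now also caps V1 at sequences of 8192 tokens; longer contexts always use the split-K V2 kernel. A small sketch of the heuristic, assuming max_num_partitions is computed as ceil(max_s / _PARTITION_SIZE):

_PARTITION_SIZE = 512

def use_paged_attention_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    # V1 skips the split-K reduction; pick it for short contexts when there is a
    # single partition or already enough (sequence x head) work to fill the GPU.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

print(use_paged_attention_v1(max_s=512, num_seqs=1, num_heads=32))     # True: one partition
print(use_paged_attention_v1(max_s=16384, num_seqs=64, num_heads=32))  # False: long context -> V2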
@ -83,7 +84,7 @@ def attention(
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
attention_ops.paged_attention_v2(
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
@ -98,4 +99,6 @@ def attention(
block_size,
max_s,
None,
"auto",
1.0,
)