feat: support attention sinks

This commit is contained in:
OlivierDehaene 2023-10-05 14:40:35 +02:00
parent 3c373dcc53
commit cc36128cda
17 changed files with 383 additions and 351 deletions

View File

@ -333,6 +333,16 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
rope_factor: Option<f32>, rope_factor: Option<f32>,
/// Sliding Window will only be used by flash attention optimized models
/// Limit the Paged Attention context window size
#[clap(long, env)]
sliding_window: Option<u32>,
/// If `sliding_window` is set, always keep the first `attention_sinks` tokens in the context
/// See: [Efficient Streaming Language Models with Attention Sinks](https://arxiv.org/abs/2309.17453)
#[clap(long, env)]
attention_sinks: Option<u32>,
/// Outputs the logs in JSON format (useful for telemetry) /// Outputs the logs in JSON format (useful for telemetry)
#[clap(long, env)] #[clap(long, env)]
json_output: bool, json_output: bool,
@ -390,6 +400,8 @@ fn shard_manager(
cuda_memory_fraction: f32, cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>, rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>, rope_factor: Option<f32>,
sliding_window: Option<u32>,
attention_sinks: Option<u32>,
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
@ -495,6 +507,17 @@ fn shard_manager(
envs.push(("ROPE_FACTOR".into(), factor.to_string().into())); envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
} }
// Detect sliding window
// Sending as env instead of CLI args to not bloat everything
// those only can be used by flash attention models, so passing information around
// for all models will complexify code unnecessarily
if let Some(sliding_window) = sliding_window {
envs.push(("SLIDING_WINDOW".into(), sliding_window.to_string().into()));
}
if let Some(attention_sinks) = attention_sinks {
envs.push(("ATTENTION_SINKS".into(), attention_sinks.to_string().into()));
}
// If huggingface_hub_cache is some, pass it to the shard // If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache { if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@ -891,6 +914,8 @@ fn spawn_shards(
let cuda_memory_fraction = args.cuda_memory_fraction; let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling; let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor; let rope_factor = args.rope_factor;
let sliding_window = args.sliding_window;
let attention_sinks = args.attention_sinks;
thread::spawn(move || { thread::spawn(move || {
shard_manager( shard_manager(
model_id, model_id,
@ -911,6 +936,8 @@ fn spawn_shards(
cuda_memory_fraction, cuda_memory_fraction,
rope_scaling, rope_scaling,
rope_factor, rope_factor,
sliding_window,
attention_sinks,
otlp_endpoint, otlp_endpoint,
status_sender, status_sender,
shutdown, shutdown,

View File

@ -297,7 +297,7 @@ def get_model(
raise ValueError("awq quantization is not supported for AutoModel") raise ValueError("awq quantization is not supported for AutoModel")
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"): elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError("4bit quantization is not supported for AutoModel") raise ValueError("4bit quantization is not supported for AutoModel")
elif (quantize == "eetq"): elif quantize == "eetq":
raise ValueError("Eetq quantization is not supported for AutoModel") raise ValueError("Eetq quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM( return CausalLM(

View File

@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer", filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
prefix="transformer",
) )
if config.quantize == "gptq": if config.quantize == "gptq":
weights._set_gptq_params(model_id) weights._set_gptq_params(model_id)

View File

@ -15,12 +15,14 @@ class CacheManager:
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
attention_sinks: int,
repeat_slots: bool, repeat_slots: bool,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
): ):
self.block_size = BLOCK_SIZE self.block_size = BLOCK_SIZE
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.attention_sinks = attention_sinks
self.repeat_slots = repeat_slots self.repeat_slots = repeat_slots
element_size = torch.tensor([], dtype=dtype).element_size() element_size = torch.tensor([], dtype=dtype).element_size()
@ -82,8 +84,23 @@ class CacheManager:
# Repeat slots in the case of context sliding window # Repeat slots in the case of context sliding window
if needed_slots > len(all_slots) and self.repeat_slots: if needed_slots > len(all_slots) and self.repeat_slots:
repeats = math.ceil(needed_slots / len(all_slots)) repeats = math.ceil(
all_slots = all_slots.repeat(repeats) needed_slots / (len(all_slots) - self.attention_sinks)
)
if self.attention_sinks > 0:
# Remove attention sinks from the repeat to not override them
all_slots = torch.cat(
[
all_slots,
all_slots[self.attention_sinks :].repeat(repeats - 1),
]
)
else:
all_slots = all_slots.repeat(repeats)
elif needed_slots > len(all_slots):
raise RuntimeError("Out of available slots. This is a bug")
allocated_slots = all_slots[:needed_slots] allocated_slots = all_slots[:needed_slots]
@ -112,6 +129,7 @@ def set_cache_manager(
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
attention_sinks: int,
repeat_slots: bool, repeat_slots: bool,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
@ -122,7 +140,14 @@ def set_cache_manager(
torch.cuda.empty_cache() torch.cuda.empty_cache()
CACHE_MANAGER = CacheManager( CACHE_MANAGER = CacheManager(
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device num_blocks,
num_layers,
num_heads,
head_size,
attention_sinks,
repeat_slots,
dtype,
device,
) )
return CACHE_MANAGER return CACHE_MANAGER

View File

@ -254,6 +254,7 @@ class FlashLlamaAttention(torch.nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
query, kv = qkv.split( query, kv = qkv.split(
@ -269,8 +270,13 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
vllm_cache_ops.reshape_and_cache( vllm_cache_ops.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor # output tensor
@ -376,6 +382,7 @@ class FlashLlamaLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -390,6 +397,7 @@ class FlashLlamaLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
# faster post attention rms norm # faster post attention rms norm
@ -442,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
@ -464,6 +473,7 @@ class FlashLlamaModel(torch.nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -492,8 +502,19 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
sliding_window: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif sliding_window != -1:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(sliding_window, max_s)
input_lengths = torch.clamp(input_lengths, max=sliding_window)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
@ -503,6 +524,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -201,9 +201,6 @@ class MistralAttention(torch.nn.Module):
weights, weights,
): ):
super().__init__() super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else 0
)
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads self.head_size = self.hidden_size // self.num_heads
@ -252,6 +249,7 @@ class MistralAttention(torch.nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
sliding_window,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
query, kv = qkv.split( query, kv = qkv.split(
@ -290,7 +288,7 @@ class MistralAttention(torch.nn.Module):
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
window_size_left=self.max_past, window_size_left=sliding_window,
) )
# Decode # Decode
else: else:
@ -381,6 +379,7 @@ class MistralLayer(nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
sliding_window,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -396,6 +395,7 @@ class MistralLayer(nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
sliding_window,
) )
# faster post attention rms norm # faster post attention rms norm
@ -449,6 +449,7 @@ class MistralModel(torch.nn.Module):
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
sliding_window: int,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
@ -472,6 +473,7 @@ class MistralModel(torch.nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
sliding_window,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -489,9 +491,6 @@ class FlashMistralForCausalLM(torch.nn.Module):
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,
) )
self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")
def forward( def forward(
self, self,
@ -504,16 +503,17 @@ class FlashMistralForCausalLM(torch.nn.Module):
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
sliding_window: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if prefill_cache_indices is not None: if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor # Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices] slots = slots[prefill_cache_indices]
else: elif sliding_window != -1:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values # kernel requires the true values
max_s = min(self.max_past, max_s) max_s = min(sliding_window, max_s)
input_lengths = torch.clamp(input_lengths, max=self.max_past) input_lengths = torch.clamp(input_lengths, max=sliding_window)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
@ -525,6 +525,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
sliding_window,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -133,6 +133,7 @@ class FlashNeoxAttention(torch.nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -141,20 +142,28 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin)
query, kv = qkv.split([1, 2], dim=1)
query = query.view(-1, self.num_heads, self.head_size)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
vllm_cache_ops.reshape_and_cache( vllm_cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor # output tensor
attn_output = torch.empty_like(qkv[:, 0]) attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attention(
qkv[:, 0], query,
qkv[:, 1], kv[:, 0],
qkv[:, 2], kv[:, 1],
attn_output, attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
@ -166,7 +175,7 @@ class FlashNeoxAttention(torch.nn.Module):
block_size = kv_cache[1].shape[3] block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention( vllm_attention_ops.single_query_cached_kv_attention(
attn_output, attn_output,
qkv[:, 0], query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
@ -245,6 +254,7 @@ class FlashNeoXLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
if self.use_parallel_residual: if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states) ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -259,6 +269,7 @@ class FlashNeoXLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -283,6 +294,7 @@ class FlashNeoXLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@ -337,6 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_in(input_ids) hidden_states = self.embed_in(input_ids)
@ -359,6 +372,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
hidden_states, _ = self.final_layer_norm(hidden_states, residual) hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@ -385,8 +399,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
sliding_window: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif sliding_window != -1:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(sliding_window, max_s)
input_lengths = torch.clamp(input_lengths, max=sliding_window)
hidden_states = self.gpt_neox( hidden_states = self.gpt_neox(
input_ids, input_ids,
position_ids, position_ids,
@ -396,6 +421,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -174,6 +174,7 @@ class FlashRWAttention(torch.nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -191,8 +192,13 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
vllm_cache_ops.reshape_and_cache( vllm_cache_ops.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output # output
@ -294,6 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@ -310,9 +317,14 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin) self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
vllm_cache_ops.reshape_and_cache( vllm_cache_ops.reshape_and_cache(
kv[:, :, 0].contiguous(), kv_to_cache[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(), kv_to_cache[:, :, 1].contiguous(),
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
slots, slots,
@ -428,6 +440,7 @@ class FlashRWLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
if self.parallel_attn: if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -442,6 +455,7 @@ class FlashRWLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
mlp_output = self.mlp(ln_hidden_states) mlp_output = self.mlp(ln_hidden_states)
@ -464,6 +478,7 @@ class FlashRWLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@ -513,6 +528,7 @@ class FlashRWLargeLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
ln_attn, residual = self.ln_attn(hidden_states, residual) ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual) ln_mlp, _ = self.ln_mlp(residual)
@ -528,6 +544,7 @@ class FlashRWLargeLayer(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
# MLP. # MLP.
@ -589,6 +606,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
@ -611,6 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
@ -638,8 +657,19 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
sliding_window: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif sliding_window != -1:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(sliding_window, max_s)
input_lengths = torch.clamp(input_lengths, max=sliding_window)
hidden_states = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
position_ids, position_ids,
@ -649,6 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -246,6 +246,7 @@ class FlashMQAttention(torch.nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
): ):
qkv = self.c_attn(hidden_states) qkv = self.c_attn(hidden_states)
@ -258,8 +259,13 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size)
if prefill_cache_indices is not None:
kv_to_cache = key_value[prefill_cache_indices]
else:
kv_to_cache = key_value
vllm_cache_ops.reshape_and_cache( vllm_cache_ops.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output # output
@ -367,6 +373,7 @@ class Block(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
hidden_states, residual = self.ln_2(hidden_states, residual) hidden_states, residual = self.ln_2(hidden_states, residual)
@ -420,6 +427,7 @@ class FlashSantacoderModel(nn.Module):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids) hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -437,6 +445,7 @@ class FlashSantacoderModel(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
@ -462,8 +471,19 @@ class FlashSantacoderForCausalLM(nn.Module):
slots: torch.Tensor, slots: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
sliding_window: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif sliding_window != -1:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(sliding_window, max_s)
input_lengths = torch.clamp(input_lengths, max=sliding_window)
hidden_states = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
position_ids, position_ids,
@ -473,6 +493,7 @@ class FlashSantacoderForCausalLM(nn.Module):
slots, slots,
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
image = image_url_or_urls image = image_url_or_urls
if image.startswith("http://") or image.startswith("https://"): if image.startswith("http://") or image.startswith("https://"):
response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)) response = requests.get(
image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
)
response.raise_for_status() response.raise_for_status()
content = response.content content = response.content
else: else:

View File

@ -9,7 +9,7 @@ import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.models.types import ( from text_generation_server.models.types import (
@ -24,6 +24,10 @@ from text_generation_server.models.cache_manager import (
set_cache_manager, set_cache_manager,
BLOCK_SIZE, BLOCK_SIZE,
) )
from text_generation_server.models.sliding_window import (
set_sliding_window_from_env,
get_sliding_window,
)
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.dist import MEMORY_FRACTION
@ -47,6 +51,12 @@ class FlashCausalLMBatch(Batch):
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor] cu_seqlen_prefill: Optional[torch.Tensor]
# Sliding window values
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Paged Attention values # Paged Attention values
# Set when creating the batch # Set when creating the batch
@ -109,6 +119,8 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
sliding_window = get_sliding_window()
batch_inputs = [] batch_inputs = []
max_truncation = 0 max_truncation = 0
for r in pb.requests: for r in pb.requests:
@ -124,6 +136,7 @@ class FlashCausalLMBatch(Batch):
needed_blocks_slots = [] needed_blocks_slots = []
start_slots = [] start_slots = []
slot_indices = [] slot_indices = []
prefill_cache_indices = []
input_lengths = [] input_lengths = []
prefix_offsets = [] prefix_offsets = []
@ -187,8 +200,15 @@ class FlashCausalLMBatch(Batch):
# Paged attention # Paged attention
# Remove one as the first token des not have a past # Remove one as the first token des not have a past
total_tokens = input_length + max_new_tokens - 1 total_tokens = input_length + max_new_tokens - 1
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
# If using sliding window
if sliding_window is not None:
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = min(needed_blocks, sliding_window.blocks)
blocks += needed_blocks blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens)) needed_blocks_slots.append((needed_blocks, total_tokens))
start_slots.append(cumulative_max_length) start_slots.append(cumulative_max_length)
@ -199,6 +219,32 @@ class FlashCausalLMBatch(Batch):
) )
slot_indices.append(request_slot_indices) slot_indices.append(request_slot_indices)
# If using sliding window
if sliding_window is not None:
# Start of the sliding window cache
start_offset = max(
0,
input_length - sliding_window.size + sliding_window.attention_sinks,
)
if sliding_window.attention_sinks > 0 and start_offset > 0:
# Attention sinks indices
request_attention_sinks_cache_indices = torch.arange(
cumulative_length,
cumulative_length
+ min(sliding_window.attention_sinks, start_offset),
dtype=torch.int64,
)
prefill_cache_indices.append(request_attention_sinks_cache_indices)
# Create tensor to slice into the kv tensor in prefill
request_prefill_cache_indices = torch.arange(
cumulative_length + start_offset,
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@ -252,12 +298,26 @@ class FlashCausalLMBatch(Batch):
position_ids = position_ids[0] position_ids = position_ids[0]
slot_indices = slot_indices[0] slot_indices = slot_indices[0]
if len(prefill_cache_indices) > 1:
prefill_cache_indices = (
torch.cat(prefill_cache_indices) if prefill_cache_indices else None
)
else:
prefill_cache_indices = (
prefill_cache_indices[0] if prefill_cache_indices else None
)
cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32 cu_seqlen_prefill, device=device, dtype=torch.int32
) )
position_ids = position_ids.to(device) position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
prefill_cache_indices = (
prefill_cache_indices.to(device)
if prefill_cache_indices is not None
else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor( input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device input_lengths, dtype=torch.int32, device=device
@ -309,6 +369,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor, top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices,
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
@ -425,7 +486,7 @@ class FlashCausalLMBatch(Batch):
# Move to GPU now that we have the whole tensor # Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
return type(self)( return FlashCausalLMBatch(
batch_id=self.batch_id, batch_id=self.batch_id,
requests=requests, requests=requests,
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
@ -454,6 +515,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor, top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
prefill_cache_indices=None,
) )
@classmethod @classmethod
@ -611,6 +673,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor, top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
prefill_cache_indices=None,
) )
def __del__(self): def __del__(self):
@ -636,11 +699,11 @@ class FlashCausalLM(Model):
device: torch.device, device: torch.device,
rank: int = 0, rank: int = 0,
world_size: int = 1, world_size: int = 1,
sliding_window: Optional[int] = None,
): ):
self.num_layers = num_layers self.num_layers = num_layers
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
self.head_size = head_size self.head_size = head_size
set_sliding_window_from_env()
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model, model=model,
@ -650,7 +713,6 @@ class FlashCausalLM(Model):
device=device, device=device,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
sliding_window=sliding_window,
) )
@property @property
@ -658,6 +720,8 @@ class FlashCausalLM(Model):
return FlashCausalLMBatch return FlashCausalLMBatch
def warmup(self, batch: FlashCausalLMBatch): def warmup(self, batch: FlashCausalLMBatch):
sliding_window = get_sliding_window()
torch.cuda.empty_cache() torch.cuda.empty_cache()
try: try:
cache_manager = set_cache_manager( cache_manager = set_cache_manager(
@ -665,7 +729,8 @@ class FlashCausalLM(Model):
self.num_layers, self.num_layers,
self.num_kv_heads, self.num_kv_heads,
self.head_size, self.head_size,
self.sliding_window is not None, sliding_window.attention_sinks if sliding_window is not None else 0,
True if sliding_window is not None else False,
self.dtype, self.dtype,
self.device, self.device,
) )
@ -705,7 +770,8 @@ class FlashCausalLM(Model):
self.num_layers, self.num_layers,
self.num_kv_heads, self.num_kv_heads,
self.head_size, self.head_size,
self.sliding_window is not None, sliding_window.attention_sinks if sliding_window is not None else 0,
True if sliding_window is not None else False,
self.dtype, self.dtype,
self.device, self.device,
) )
@ -713,8 +779,10 @@ class FlashCausalLM(Model):
return int(num_blocks * BLOCK_SIZE) return int(num_blocks * BLOCK_SIZE)
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
sliding_window = get_sliding_window()
# Model Forward # Model Forward
return self.model.forward( logits = self.model.forward(
input_ids=batch.input_ids, input_ids=batch.input_ids,
position_ids=batch.position_ids, position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill, cu_seqlen_prefill=batch.cu_seqlen_prefill,
@ -723,8 +791,13 @@ class FlashCausalLM(Model):
slots=batch.slots[batch.slot_indices], slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor, input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen, max_s=batch.max_seqlen,
prefill_cache_indices=batch.prefill_cache_indices,
sliding_window=sliding_window.size if sliding_window is not None else -1,
lm_head_indices=batch.prefill_head_indices, lm_head_indices=batch.prefill_head_indices,
) )
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
def generate_token( def generate_token(

View File

@ -1,21 +1,14 @@
import math
import torch import torch
import torch.distributed import torch.distributed
import numpy as np
from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type from typing import Optional
from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE from text_generation_server.models.sliding_window import (
from text_generation_server.models.cache_manager import ( set_sliding_window,
get_cache_manager, get_sliding_window,
set_cache_manager,
) )
from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM, FlashMistralForCausalLM,
@ -25,255 +18,10 @@ from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
HeterogeneousNextTokenChooser,
StoppingCriteria,
) )
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None
# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor] = None
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
batch_inputs = []
max_truncation = 0
for r in pb.requests:
batch_inputs.append(r.inputs)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
position_ids = []
cu_seqlen_prefill = [0]
needed_blocks_slots = []
start_slots = []
slot_indices = []
prefill_cache_indices = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
requests_idx_mapping = {}
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
prefill_out_cumulative_length = 0
blocks = 0
max_seqlen = 0
max_length = 0
max_blocks = 0
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
input_length = len(tokenized_input)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
all_input_ids.append(tokenized_input)
# Position ids
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention
# Remove one as the first token des not have a past
total_tokens = input_length + max_new_tokens - 1
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = min(
math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - SLIDING_WINDOW),
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
if r.prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1], dtype=torch.int32
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64
)
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = prefill_cache_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
block_tables=None,
block_tables_tensor=None,
slots=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices,
)
class FlashMistral(FlashCausalLM): class FlashMistral(FlashCausalLM):
def __init__( def __init__(
@ -284,9 +32,6 @@ class FlashMistral(FlashCausalLM):
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
@ -308,8 +53,7 @@ class FlashMistral(FlashCausalLM):
config.quantize = quantize config.quantize = quantize
# Set context windows # Set context windows
SLIDING_WINDOW = config.sliding_window set_sliding_window(config.sliding_window, 0)
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -331,27 +75,4 @@ class FlashMistral(FlashCausalLM):
device=device, device=device,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
sliding_window=config.sliding_window,
) )
@property
def batch_type(self) -> Type[FlashMistralBatch]:
return FlashMistralBatch
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
logits = self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill,
kv_cache=get_cache_manager().kv_cache,
block_tables=batch.block_tables_tensor,
slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=batch.prefill_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits

View File

@ -3,10 +3,11 @@ import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, Generation from text_generation_server.models.types import Batch, Generation
from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.models.sliding_window import get_sliding_window
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)
@ -21,7 +22,6 @@ class Model(ABC):
device: torch.device, device: torch.device,
rank: int = 0, rank: int = 0,
world_size: int = 1, world_size: int = 1,
sliding_window: Optional[int] = None,
): ):
self.model = model.eval() self.model = model.eval()
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -31,7 +31,6 @@ class Model(ABC):
self.device = device self.device = device
self.rank = rank self.rank = rank
self.world_size = world_size self.world_size = world_size
self.sliding_window = sliding_window
self.has_position_ids = ( self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None) inspect.signature(model.forward).parameters.get("position_ids", None)
@ -42,14 +41,15 @@ class Model(ABC):
@property @property
def info(self) -> InfoResponse: def info(self) -> InfoResponse:
if self.requires_padding and self.sliding_window is not None: sliding_window = get_sliding_window()
if self.requires_padding and sliding_window is not None:
raise NotImplementedError("sliding_window is not implemented with padding") raise NotImplementedError("sliding_window is not implemented with padding")
return InfoResponse( return InfoResponse(
requires_padding=self.requires_padding, requires_padding=self.requires_padding,
dtype=str(self.dtype), dtype=str(self.dtype),
device_type=self.device.type, device_type=self.device.type,
window_size=self.sliding_window, window_size=sliding_window.size if sliding_window is not None else None,
) )
@property @property

View File

@ -0,0 +1,41 @@
import os
import math
from typing import Optional
from text_generation_server.models.cache_manager import BLOCK_SIZE
SLIDING_WINDOW: Optional["SlidingWindow"] = None
class SlidingWindow:
def __init__(self, size: int, attention_sinks: int):
self.size = size
self.blocks = math.ceil(size / BLOCK_SIZE)
self.attention_sinks = attention_sinks
@classmethod
def from_env(cls) -> Optional["SlidingWindow"]:
sliding_window_env = os.getenv("SLIDING_WINDOW", None)
if sliding_window_env is not None:
return cls(int(sliding_window_env), int(os.getenv("ATTENTION_SINKS", 0)))
return None
def set_sliding_window(size: int, attention_sinks: int) -> SlidingWindow:
global SLIDING_WINDOW
SLIDING_WINDOW = SlidingWindow(size, attention_sinks)
return SLIDING_WINDOW
def set_sliding_window_from_env() -> Optional[SlidingWindow]:
global SLIDING_WINDOW
env_sliding_window = SlidingWindow.from_env()
if env_sliding_window is not None:
SLIDING_WINDOW = env_sliding_window
return SLIDING_WINDOW
def get_sliding_window() -> Optional[SlidingWindow]:
global SLIDING_WINDOW
return SLIDING_WINDOW

View File

@ -604,15 +604,16 @@ try:
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"], max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0, base=10000.0,
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
extrapolation_factor=1, extrapolation_factor=1,
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1 beta_slow=1,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -645,15 +646,16 @@ try:
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"], max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0, base=10000.0,
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
extrapolation_factor=1, extrapolation_factor=1,
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1 beta_slow=1,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -734,19 +736,27 @@ try:
self._cos_cached = torch.cos(freqs).to(dtype) self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype)
# Inverse dim formula to find dim based on number of rotations # Inverse dim formula to find dim based on number of rotations
import math import math
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) def find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (
dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
) / (2 * math.log(base))
# Find dim range bounds based on rotations # Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): def find_correction_range(
low = math.floor(find_correction_dim( low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
low_rot, dim, base, max_position_embeddings)) ):
high = math.ceil(find_correction_dim( low = math.floor(
high_rot, dim, base, max_position_embeddings)) find_correction_dim(low_rot, dim, base, max_position_embeddings)
return max(low, 0), min(high, dim-1) # Clamp values just in case )
high = math.ceil(
find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim): def linear_ramp_mask(min, max, dim):
if min == max: if min == max:
@ -762,7 +772,19 @@ try:
return 0.1 * math.log(scale) + 1.0 return 0.1 * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow): def __init__(
self,
dim,
max_position_embeddings,
base,
device,
scaling_factor,
*,
extrapolation_factor,
attn_factor,
beta_fast,
beta_slow,
):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor) super().__init__(inv_freq, scaling_factor)
self.dim = dim self.dim = dim
@ -772,7 +794,9 @@ try:
self.attn_factor = attn_factor self.attn_factor = attn_factor
self.beta_fast = beta_fast self.beta_fast = beta_fast
self.beta_slow = beta_slow self.beta_slow = beta_slow
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation self.mscale = float(
get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
@ -788,13 +812,26 @@ try:
) )
freqs = 1.0 / inv_freq_extrapolation freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings) low, high = find_correction_range(
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation self.beta_fast,
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask self.beta_slow,
self.dim,
self.base,
self.max_position_embeddings,
)
inv_freq_mask = (
1
- linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
self.inv_freq = inv_freq self.inv_freq = inv_freq
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation self.mscale = float(
get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)

View File

@ -16,7 +16,7 @@ class Weights:
dtype, dtype,
process_group, process_group,
aliases: Optional[Dict[str, List[str]]] = None, aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None prefix: Optional[str] = None,
): ):
routing = {} routing = {}
for filename in filenames: for filename in filenames:
@ -213,7 +213,8 @@ class Weights:
bits, groupsize = self._get_gptq_params() bits, groupsize = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq"
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]

View File

@ -21,14 +21,14 @@ def main():
block = [] block = []
for line in lines: for line in lines:
if line.startswith(" -") or line.startswith(" -"): if line.startswith(" -") or line.startswith(" -"):
rendered_block = '\n'.join(block) rendered_block = "\n".join(block)
if header: if header:
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
else: else:
final_doc += f"```shell\n{rendered_block}\n```\n" final_doc += f"```shell\n{rendered_block}\n```\n"
block = [] block = []
tokens = line.split("<") tokens = line.split("<")
if len(tokens)>1: if len(tokens) > 1:
header = tokens[-1][:-1] header = tokens[-1][:-1]
else: else:
header = line.split("--")[-1] header = line.split("--")[-1]
@ -36,7 +36,7 @@ def main():
block.append(line) block.append(line)
rendered_block = '\n'.join(block) rendered_block = "\n".join(block)
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
block = [] block = []