feat: support attention sinks
This commit is contained in:
parent
3c373dcc53
commit
cc36128cda
|
@ -333,6 +333,16 @@ struct Args {
|
|||
#[clap(long, env)]
|
||||
rope_factor: Option<f32>,
|
||||
|
||||
/// Sliding Window will only be used by flash attention optimized models
|
||||
/// Limit the Paged Attention context window size
|
||||
#[clap(long, env)]
|
||||
sliding_window: Option<u32>,
|
||||
|
||||
/// If `sliding_window` is set, always keep the first `attention_sinks` tokens in the context
|
||||
/// See: [Efficient Streaming Language Models with Attention Sinks](https://arxiv.org/abs/2309.17453)
|
||||
#[clap(long, env)]
|
||||
attention_sinks: Option<u32>,
|
||||
|
||||
/// Outputs the logs in JSON format (useful for telemetry)
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
|
@ -390,6 +400,8 @@ fn shard_manager(
|
|||
cuda_memory_fraction: f32,
|
||||
rope_scaling: Option<RopeScaling>,
|
||||
rope_factor: Option<f32>,
|
||||
sliding_window: Option<u32>,
|
||||
attention_sinks: Option<u32>,
|
||||
otlp_endpoint: Option<String>,
|
||||
status_sender: mpsc::Sender<ShardStatus>,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
|
@ -495,6 +507,17 @@ fn shard_manager(
|
|||
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
||||
}
|
||||
|
||||
// Detect sliding window
|
||||
// Sending as env instead of CLI args to not bloat everything
|
||||
// those only can be used by flash attention models, so passing information around
|
||||
// for all models will complexify code unnecessarily
|
||||
if let Some(sliding_window) = sliding_window {
|
||||
envs.push(("SLIDING_WINDOW".into(), sliding_window.to_string().into()));
|
||||
}
|
||||
if let Some(attention_sinks) = attention_sinks {
|
||||
envs.push(("ATTENTION_SINKS".into(), attention_sinks.to_string().into()));
|
||||
}
|
||||
|
||||
// If huggingface_hub_cache is some, pass it to the shard
|
||||
// Useful when running inside a docker container
|
||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||
|
@ -891,6 +914,8 @@ fn spawn_shards(
|
|||
let cuda_memory_fraction = args.cuda_memory_fraction;
|
||||
let rope_scaling = args.rope_scaling;
|
||||
let rope_factor = args.rope_factor;
|
||||
let sliding_window = args.sliding_window;
|
||||
let attention_sinks = args.attention_sinks;
|
||||
thread::spawn(move || {
|
||||
shard_manager(
|
||||
model_id,
|
||||
|
@ -911,6 +936,8 @@ fn spawn_shards(
|
|||
cuda_memory_fraction,
|
||||
rope_scaling,
|
||||
rope_factor,
|
||||
sliding_window,
|
||||
attention_sinks,
|
||||
otlp_endpoint,
|
||||
status_sender,
|
||||
shutdown,
|
||||
|
|
|
@ -297,7 +297,7 @@ def get_model(
|
|||
raise ValueError("awq quantization is not supported for AutoModel")
|
||||
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
||||
raise ValueError("4bit quantization is not supported for AutoModel")
|
||||
elif (quantize == "eetq"):
|
||||
elif quantize == "eetq":
|
||||
raise ValueError("Eetq quantization is not supported for AutoModel")
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return CausalLM(
|
||||
|
|
|
@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM):
|
|||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
prefix="transformer",
|
||||
)
|
||||
if config.quantize == "gptq":
|
||||
weights._set_gptq_params(model_id)
|
||||
|
|
|
@ -15,12 +15,14 @@ class CacheManager:
|
|||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
attention_sinks: int,
|
||||
repeat_slots: bool,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.block_size = BLOCK_SIZE
|
||||
self.num_blocks = num_blocks
|
||||
self.attention_sinks = attention_sinks
|
||||
self.repeat_slots = repeat_slots
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
|
@ -82,9 +84,24 @@ class CacheManager:
|
|||
|
||||
# Repeat slots in the case of context sliding window
|
||||
if needed_slots > len(all_slots) and self.repeat_slots:
|
||||
repeats = math.ceil(needed_slots / len(all_slots))
|
||||
repeats = math.ceil(
|
||||
needed_slots / (len(all_slots) - self.attention_sinks)
|
||||
)
|
||||
|
||||
if self.attention_sinks > 0:
|
||||
# Remove attention sinks from the repeat to not override them
|
||||
all_slots = torch.cat(
|
||||
[
|
||||
all_slots,
|
||||
all_slots[self.attention_sinks :].repeat(repeats - 1),
|
||||
]
|
||||
)
|
||||
else:
|
||||
all_slots = all_slots.repeat(repeats)
|
||||
|
||||
elif needed_slots > len(all_slots):
|
||||
raise RuntimeError("Out of available slots. This is a bug")
|
||||
|
||||
allocated_slots = all_slots[:needed_slots]
|
||||
|
||||
slots.append(allocated_slots)
|
||||
|
@ -112,6 +129,7 @@ def set_cache_manager(
|
|||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
attention_sinks: int,
|
||||
repeat_slots: bool,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
|
@ -122,7 +140,14 @@ def set_cache_manager(
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
CACHE_MANAGER = CacheManager(
|
||||
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
|
||||
num_blocks,
|
||||
num_layers,
|
||||
num_heads,
|
||||
head_size,
|
||||
attention_sinks,
|
||||
repeat_slots,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
return CACHE_MANAGER
|
||||
|
||||
|
|
|
@ -254,6 +254,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
|
@ -269,8 +270,13 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
vllm_cache_ops.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
|
@ -376,6 +382,7 @@ class FlashLlamaLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
|
@ -390,6 +397,7 @@ class FlashLlamaLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
|
@ -442,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
|
@ -464,6 +473,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
@ -492,8 +502,19 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
sliding_window: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
elif sliding_window != -1:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
max_s = min(sliding_window, max_s)
|
||||
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
|
@ -503,6 +524,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -201,9 +201,6 @@ class MistralAttention(torch.nn.Module):
|
|||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_past = (
|
||||
config.sliding_window if config.sliding_window is not None else 0
|
||||
)
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
@ -252,6 +249,7 @@ class MistralAttention(torch.nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
sliding_window,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
|
@ -290,7 +288,7 @@ class MistralAttention(torch.nn.Module):
|
|||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
window_size_left=sliding_window,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
|
@ -381,6 +379,7 @@ class MistralLayer(nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
sliding_window,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
|
@ -396,6 +395,7 @@ class MistralLayer(nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
sliding_window,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
|
@ -449,6 +449,7 @@ class MistralModel(torch.nn.Module):
|
|||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
sliding_window: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
|
@ -472,6 +473,7 @@ class MistralModel(torch.nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
sliding_window,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
@ -489,9 +491,6 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
prefix="lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
if self.max_past is None:
|
||||
raise ValueError("max_past cannot be None")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -504,16 +503,17 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
sliding_window: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
else:
|
||||
elif sliding_window != -1:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
max_s = min(self.max_past, max_s)
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past)
|
||||
max_s = min(sliding_window, max_s)
|
||||
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
@ -525,6 +525,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
sliding_window,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -133,6 +133,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
|
@ -141,20 +142,28 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||
|
||||
query, kv = qkv.split([1, 2], dim=1)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
vllm_cache_ops.reshape_and_cache(
|
||||
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attention(
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
|
@ -166,7 +175,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
block_size = kv_cache[1].shape[3]
|
||||
vllm_attention_ops.single_query_cached_kv_attention(
|
||||
attn_output,
|
||||
qkv[:, 0],
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
|
@ -245,6 +254,7 @@ class FlashNeoXLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
if self.use_parallel_residual:
|
||||
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
|
||||
|
@ -259,6 +269,7 @@ class FlashNeoXLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
|
||||
|
@ -283,6 +294,7 @@ class FlashNeoXLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
|
@ -337,6 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_in(input_ids)
|
||||
|
||||
|
@ -359,6 +372,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
|
||||
|
@ -385,8 +399,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
sliding_window: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
elif sliding_window != -1:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
max_s = min(sliding_window, max_s)
|
||||
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||
|
||||
hidden_states = self.gpt_neox(
|
||||
input_ids,
|
||||
position_ids,
|
||||
|
@ -396,6 +421,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -174,6 +174,7 @@ class FlashRWAttention(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
|
||||
|
@ -191,8 +192,13 @@ class FlashRWAttention(torch.nn.Module):
|
|||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
vllm_cache_ops.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output
|
||||
|
@ -294,6 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
|
||||
|
@ -310,9 +317,14 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
vllm_cache_ops.reshape_and_cache(
|
||||
kv[:, :, 0].contiguous(),
|
||||
kv[:, :, 1].contiguous(),
|
||||
kv_to_cache[:, :, 0].contiguous(),
|
||||
kv_to_cache[:, :, 1].contiguous(),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
slots,
|
||||
|
@ -428,6 +440,7 @@ class FlashRWLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
if self.parallel_attn:
|
||||
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
@ -442,6 +455,7 @@ class FlashRWLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(ln_hidden_states)
|
||||
|
@ -464,6 +478,7 @@ class FlashRWLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
|
@ -513,6 +528,7 @@ class FlashRWLargeLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
||||
ln_mlp, _ = self.ln_mlp(residual)
|
||||
|
@ -528,6 +544,7 @@ class FlashRWLargeLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
# MLP.
|
||||
|
@ -589,6 +606,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
|
||||
|
@ -611,6 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
@ -638,8 +657,19 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
sliding_window: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
elif sliding_window != -1:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
max_s = min(sliding_window, max_s)
|
||||
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
position_ids,
|
||||
|
@ -649,6 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -246,6 +246,7 @@ class FlashMQAttention(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
qkv = self.c_attn(hidden_states)
|
||||
|
||||
|
@ -258,8 +259,13 @@ class FlashMQAttention(torch.nn.Module):
|
|||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key_value = key_value.view(-1, 2, 1, self.head_size)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = key_value[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = key_value
|
||||
|
||||
vllm_cache_ops.reshape_and_cache(
|
||||
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output
|
||||
|
@ -367,6 +373,7 @@ class Block(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.ln_2(hidden_states, residual)
|
||||
|
@ -420,6 +427,7 @@ class FlashSantacoderModel(nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||
|
||||
|
@ -437,6 +445,7 @@ class FlashSantacoderModel(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
@ -462,8 +471,19 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
sliding_window: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
elif sliding_window != -1:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
max_s = min(sliding_window, max_s)
|
||||
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
position_ids,
|
||||
|
@ -473,6 +493,7 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
|
|||
image = image_url_or_urls
|
||||
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5))
|
||||
response = requests.get(
|
||||
image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
|
||||
)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
else:
|
||||
|
|
|
@ -9,7 +9,7 @@ import numpy as np
|
|||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from typing import Optional, Tuple, List, Type, Union, Dict
|
||||
from typing import Optional, Tuple, List, Type, Dict
|
||||
|
||||
from text_generation_server.models import Model
|
||||
from text_generation_server.models.types import (
|
||||
|
@ -24,6 +24,10 @@ from text_generation_server.models.cache_manager import (
|
|||
set_cache_manager,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
from text_generation_server.models.sliding_window import (
|
||||
set_sliding_window_from_env,
|
||||
get_sliding_window,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
|
@ -47,6 +51,12 @@ class FlashCausalLMBatch(Batch):
|
|||
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
|
||||
cu_seqlen_prefill: Optional[torch.Tensor]
|
||||
|
||||
# Sliding window values
|
||||
|
||||
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
||||
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
||||
prefill_cache_indices: Optional[torch.Tensor]
|
||||
|
||||
# Paged Attention values
|
||||
|
||||
# Set when creating the batch
|
||||
|
@ -109,6 +119,8 @@ class FlashCausalLMBatch(Batch):
|
|||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
sliding_window = get_sliding_window()
|
||||
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
|
@ -124,6 +136,7 @@ class FlashCausalLMBatch(Batch):
|
|||
needed_blocks_slots = []
|
||||
start_slots = []
|
||||
slot_indices = []
|
||||
prefill_cache_indices = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
|
@ -187,8 +200,15 @@ class FlashCausalLMBatch(Batch):
|
|||
# Paged attention
|
||||
# Remove one as the first token des not have a past
|
||||
total_tokens = input_length + max_new_tokens - 1
|
||||
|
||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
||||
|
||||
# If using sliding window
|
||||
if sliding_window is not None:
|
||||
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
||||
needed_blocks = min(needed_blocks, sliding_window.blocks)
|
||||
blocks += needed_blocks
|
||||
|
||||
needed_blocks_slots.append((needed_blocks, total_tokens))
|
||||
start_slots.append(cumulative_max_length)
|
||||
|
||||
|
@ -199,6 +219,32 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
# If using sliding window
|
||||
if sliding_window is not None:
|
||||
# Start of the sliding window cache
|
||||
start_offset = max(
|
||||
0,
|
||||
input_length - sliding_window.size + sliding_window.attention_sinks,
|
||||
)
|
||||
|
||||
if sliding_window.attention_sinks > 0 and start_offset > 0:
|
||||
# Attention sinks indices
|
||||
request_attention_sinks_cache_indices = torch.arange(
|
||||
cumulative_length,
|
||||
cumulative_length
|
||||
+ min(sliding_window.attention_sinks, start_offset),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
prefill_cache_indices.append(request_attention_sinks_cache_indices)
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
request_prefill_cache_indices = torch.arange(
|
||||
cumulative_length + start_offset,
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||
|
||||
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||
|
||||
|
@ -252,12 +298,26 @@ class FlashCausalLMBatch(Batch):
|
|||
position_ids = position_ids[0]
|
||||
slot_indices = slot_indices[0]
|
||||
|
||||
if len(prefill_cache_indices) > 1:
|
||||
prefill_cache_indices = (
|
||||
torch.cat(prefill_cache_indices) if prefill_cache_indices else None
|
||||
)
|
||||
else:
|
||||
prefill_cache_indices = (
|
||||
prefill_cache_indices[0] if prefill_cache_indices else None
|
||||
)
|
||||
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
cu_seqlen_prefill, device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
position_ids = position_ids.to(device)
|
||||
slot_indices = slot_indices.to(device)
|
||||
prefill_cache_indices = (
|
||||
prefill_cache_indices.to(device)
|
||||
if prefill_cache_indices is not None
|
||||
else None
|
||||
)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
input_lengths_tensor = torch.tensor(
|
||||
input_lengths, dtype=torch.int32, device=device
|
||||
|
@ -309,6 +369,7 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
|
@ -425,7 +486,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Move to GPU now that we have the whole tensor
|
||||
slot_indices = slot_indices.to(device)
|
||||
|
||||
return type(self)(
|
||||
return FlashCausalLMBatch(
|
||||
batch_id=self.batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
|
@ -454,6 +515,7 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -611,6 +673,7 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
|
@ -636,11 +699,11 @@ class FlashCausalLM(Model):
|
|||
device: torch.device,
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_size = head_size
|
||||
set_sliding_window_from_env()
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model,
|
||||
|
@ -650,7 +713,6 @@ class FlashCausalLM(Model):
|
|||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
|
||||
@property
|
||||
|
@ -658,6 +720,8 @@ class FlashCausalLM(Model):
|
|||
return FlashCausalLMBatch
|
||||
|
||||
def warmup(self, batch: FlashCausalLMBatch):
|
||||
sliding_window = get_sliding_window()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
cache_manager = set_cache_manager(
|
||||
|
@ -665,7 +729,8 @@ class FlashCausalLM(Model):
|
|||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.sliding_window is not None,
|
||||
sliding_window.attention_sinks if sliding_window is not None else 0,
|
||||
True if sliding_window is not None else False,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
|
@ -705,7 +770,8 @@ class FlashCausalLM(Model):
|
|||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.sliding_window is not None,
|
||||
sliding_window.attention_sinks if sliding_window is not None else 0,
|
||||
True if sliding_window is not None else False,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
|
@ -713,8 +779,10 @@ class FlashCausalLM(Model):
|
|||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
||||
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
sliding_window = get_sliding_window()
|
||||
|
||||
# Model Forward
|
||||
return self.model.forward(
|
||||
logits = self.model.forward(
|
||||
input_ids=batch.input_ids,
|
||||
position_ids=batch.position_ids,
|
||||
cu_seqlen_prefill=batch.cu_seqlen_prefill,
|
||||
|
@ -723,8 +791,13 @@ class FlashCausalLM(Model):
|
|||
slots=batch.slots[batch.slot_indices],
|
||||
input_lengths=batch.input_lengths_tensor,
|
||||
max_s=batch.max_seqlen,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
sliding_window=sliding_window.size if sliding_window is not None else -1,
|
||||
lm_head_indices=batch.prefill_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
|
|
|
@ -1,21 +1,14 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.models.llama import LlamaTokenizerFast
|
||||
from typing import Optional, Tuple, Type
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
set_cache_manager,
|
||||
from text_generation_server.models.sliding_window import (
|
||||
set_sliding_window,
|
||||
get_sliding_window,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
|
@ -25,255 +18,10 @@ from text_generation_server.utils import (
|
|||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
HeterogeneousNextTokenChooser,
|
||||
StoppingCriteria,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
||||
|
||||
|
||||
# Adds windowing logic to FlashCausalLMBatch
|
||||
@dataclass
|
||||
class FlashMistralBatch(FlashCausalLMBatch):
|
||||
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
||||
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
|
||||
position_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
needed_blocks_slots = []
|
||||
start_slots = []
|
||||
slot_indices = []
|
||||
prefill_cache_indices = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
all_prefill_logprobs = True
|
||||
no_prefill_logprobs = True
|
||||
prefill_head_indices = []
|
||||
prefill_next_token_indices = []
|
||||
prefill_cu_outlens = [0]
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_max_length = 0
|
||||
prefill_out_cumulative_length = 0
|
||||
|
||||
blocks = 0
|
||||
max_seqlen = 0
|
||||
max_length = 0
|
||||
max_blocks = 0
|
||||
|
||||
# Parse batch
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
zip(pb.requests, batch_tokenized_inputs)
|
||||
):
|
||||
# request id -> idx in list mapping
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
||||
tokenized_input = tokenized_input[-r.truncate :]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
|
||||
prefix_offsets.append(input_length - 5)
|
||||
read_offsets.append(input_length)
|
||||
|
||||
all_input_ids.append(tokenized_input)
|
||||
|
||||
# Position ids
|
||||
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
||||
position_ids.append(request_position_ids)
|
||||
|
||||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlen_prefill.append(cumulative_length + input_length)
|
||||
|
||||
next_token_chooser_parameters.append(r.parameters)
|
||||
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
max_new_tokens = stopping_criteria.max_new_tokens
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
|
||||
# Paged attention
|
||||
# Remove one as the first token des not have a past
|
||||
total_tokens = input_length + max_new_tokens - 1
|
||||
|
||||
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
||||
needed_blocks = min(
|
||||
math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
|
||||
)
|
||||
blocks += needed_blocks
|
||||
|
||||
needed_blocks_slots.append((needed_blocks, total_tokens))
|
||||
start_slots.append(cumulative_max_length)
|
||||
|
||||
request_slot_indices = torch.arange(
|
||||
cumulative_max_length,
|
||||
cumulative_max_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
request_prefill_cache_indices = torch.arange(
|
||||
cumulative_length + max(0, input_length - SLIDING_WINDOW),
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||
|
||||
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||
|
||||
if r.prefill_logprobs:
|
||||
prefill_head_indices.append(request_position_ids + cumulative_length)
|
||||
prefill_next_token_indices.append(
|
||||
prefill_out_cumulative_length + input_length - 1
|
||||
)
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
|
||||
prefill_out_cumulative_length += input_length
|
||||
else:
|
||||
prefill_head_indices.append(
|
||||
torch.tensor(
|
||||
[cumulative_length + input_length - 1], dtype=torch.int32
|
||||
)
|
||||
)
|
||||
prefill_next_token_indices.append(prefill_out_cumulative_length)
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
# Update
|
||||
cumulative_length += input_length
|
||||
cumulative_max_length += total_tokens
|
||||
max_seqlen = max(max_seqlen, input_length)
|
||||
max_blocks = max(max_blocks, needed_blocks)
|
||||
max_length = max(max_length, input_length + max_new_tokens)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype, device
|
||||
)
|
||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||
|
||||
# Padded all_input_ids_tensor
|
||||
all_input_ids_tensor = np.zeros(
|
||||
(len(all_input_ids), max_length), dtype=np.int64
|
||||
)
|
||||
for i, input_ids in enumerate(all_input_ids):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
# Create tensors on device
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
if len(pb.requests) > 1:
|
||||
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||
position_ids = torch.cat(position_ids)
|
||||
slot_indices = torch.cat(slot_indices)
|
||||
prefill_cache_indices = torch.cat(prefill_cache_indices)
|
||||
else:
|
||||
input_ids = all_input_ids[0]
|
||||
position_ids = position_ids[0]
|
||||
slot_indices = slot_indices[0]
|
||||
prefill_cache_indices = prefill_cache_indices[0]
|
||||
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
cu_seqlen_prefill, device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
position_ids = position_ids.to(device)
|
||||
slot_indices = slot_indices.to(device)
|
||||
prefill_cache_indices = prefill_cache_indices.to(device)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
input_lengths_tensor = torch.tensor(
|
||||
input_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||
elif no_prefill_logprobs:
|
||||
prefill_head_indices = cu_seqlen_prefill[1:] - 1
|
||||
prefill_next_token_indices = None
|
||||
else:
|
||||
prefill_head_indices = torch.tensor(
|
||||
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
|
||||
)
|
||||
prefill_next_token_indices = torch.tensor(
|
||||
prefill_next_token_indices, dtype=torch.int64, device=device
|
||||
)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
start_slots=start_slots,
|
||||
slot_indices=slot_indices,
|
||||
needed_blocks_slots=needed_blocks_slots,
|
||||
block_tables=None,
|
||||
block_tables_tensor=None,
|
||||
slots=None,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=prefill_head_indices,
|
||||
prefill_next_token_indices=prefill_next_token_indices,
|
||||
prefill_cu_outlens=prefill_cu_outlens,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
|
||||
|
||||
class FlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
|
@ -284,9 +32,6 @@ class FlashMistral(FlashCausalLM):
|
|||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
@ -308,8 +53,7 @@ class FlashMistral(FlashCausalLM):
|
|||
config.quantize = quantize
|
||||
|
||||
# Set context windows
|
||||
SLIDING_WINDOW = config.sliding_window
|
||||
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
set_sliding_window(config.sliding_window, 0)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
|
@ -331,27 +75,4 @@ class FlashMistral(FlashCausalLM):
|
|||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
sliding_window=config.sliding_window,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[FlashMistralBatch]:
|
||||
return FlashMistralBatch
|
||||
|
||||
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Model Forward
|
||||
logits = self.model.forward(
|
||||
input_ids=batch.input_ids,
|
||||
position_ids=batch.position_ids,
|
||||
cu_seqlen_prefill=batch.cu_seqlen_prefill,
|
||||
kv_cache=get_cache_manager().kv_cache,
|
||||
block_tables=batch.block_tables_tensor,
|
||||
slots=batch.slots[batch.slot_indices],
|
||||
input_lengths=batch.input_lengths_tensor,
|
||||
max_s=batch.max_seqlen,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=batch.prefill_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits
|
||||
|
|
|
@ -3,10 +3,11 @@ import torch
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from text_generation_server.models.types import Batch, Generation
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
from text_generation_server.models.sliding_window import get_sliding_window
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
||||
|
@ -21,7 +22,6 @@ class Model(ABC):
|
|||
device: torch.device,
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
self.model = model.eval()
|
||||
self.tokenizer = tokenizer
|
||||
|
@ -31,7 +31,6 @@ class Model(ABC):
|
|||
self.device = device
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.has_position_ids = (
|
||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||
|
@ -42,14 +41,15 @@ class Model(ABC):
|
|||
|
||||
@property
|
||||
def info(self) -> InfoResponse:
|
||||
if self.requires_padding and self.sliding_window is not None:
|
||||
sliding_window = get_sliding_window()
|
||||
if self.requires_padding and sliding_window is not None:
|
||||
raise NotImplementedError("sliding_window is not implemented with padding")
|
||||
|
||||
return InfoResponse(
|
||||
requires_padding=self.requires_padding,
|
||||
dtype=str(self.dtype),
|
||||
device_type=self.device.type,
|
||||
window_size=self.sliding_window,
|
||||
window_size=sliding_window.size if sliding_window is not None else None,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
import os
|
||||
import math
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models.cache_manager import BLOCK_SIZE
|
||||
|
||||
SLIDING_WINDOW: Optional["SlidingWindow"] = None
|
||||
|
||||
|
||||
class SlidingWindow:
|
||||
def __init__(self, size: int, attention_sinks: int):
|
||||
self.size = size
|
||||
self.blocks = math.ceil(size / BLOCK_SIZE)
|
||||
self.attention_sinks = attention_sinks
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> Optional["SlidingWindow"]:
|
||||
sliding_window_env = os.getenv("SLIDING_WINDOW", None)
|
||||
if sliding_window_env is not None:
|
||||
return cls(int(sliding_window_env), int(os.getenv("ATTENTION_SINKS", 0)))
|
||||
return None
|
||||
|
||||
|
||||
def set_sliding_window(size: int, attention_sinks: int) -> SlidingWindow:
|
||||
global SLIDING_WINDOW
|
||||
SLIDING_WINDOW = SlidingWindow(size, attention_sinks)
|
||||
return SLIDING_WINDOW
|
||||
|
||||
|
||||
def set_sliding_window_from_env() -> Optional[SlidingWindow]:
|
||||
global SLIDING_WINDOW
|
||||
env_sliding_window = SlidingWindow.from_env()
|
||||
if env_sliding_window is not None:
|
||||
SLIDING_WINDOW = env_sliding_window
|
||||
return SLIDING_WINDOW
|
||||
|
||||
|
||||
def get_sliding_window() -> Optional[SlidingWindow]:
|
||||
global SLIDING_WINDOW
|
||||
return SLIDING_WINDOW
|
|
@ -604,15 +604,16 @@ try:
|
|||
elif rope_scaling["type"] == "yarn":
|
||||
return YarnPositionRotaryEmbedding(
|
||||
dim=2 * inv_freq.shape[0],
|
||||
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
||||
max_position_embeddings=rope_scaling[
|
||||
"original_max_position_embeddings"
|
||||
],
|
||||
base=10000.0,
|
||||
device=inv_freq.device,
|
||||
scaling_factor=scaling_factor,
|
||||
extrapolation_factor=1,
|
||||
attn_factor=1,
|
||||
beta_fast=32,
|
||||
beta_slow=1
|
||||
|
||||
beta_slow=1,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
@ -645,15 +646,16 @@ try:
|
|||
elif rope_scaling["type"] == "yarn":
|
||||
return YarnPositionRotaryEmbedding(
|
||||
dim=2 * inv_freq.shape[0],
|
||||
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
||||
max_position_embeddings=rope_scaling[
|
||||
"original_max_position_embeddings"
|
||||
],
|
||||
base=10000.0,
|
||||
device=inv_freq.device,
|
||||
scaling_factor=scaling_factor,
|
||||
extrapolation_factor=1,
|
||||
attn_factor=1,
|
||||
beta_fast=32,
|
||||
beta_slow=1
|
||||
|
||||
beta_slow=1,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
@ -734,18 +736,26 @@ try:
|
|||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
import math
|
||||
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
||||
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
|
||||
|
||||
def find_correction_dim(
|
||||
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
return (
|
||||
dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
|
||||
) / (2 * math.log(base))
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
||||
low = math.floor(find_correction_dim(
|
||||
low_rot, dim, base, max_position_embeddings))
|
||||
high = math.ceil(find_correction_dim(
|
||||
high_rot, dim, base, max_position_embeddings))
|
||||
def find_correction_range(
|
||||
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
low = math.floor(
|
||||
find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||
|
||||
def linear_ramp_mask(min, max, dim):
|
||||
|
@ -762,7 +772,19 @@ try:
|
|||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
device,
|
||||
scaling_factor,
|
||||
*,
|
||||
extrapolation_factor,
|
||||
attn_factor,
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
self.dim = dim
|
||||
|
@ -772,7 +794,9 @@ try:
|
|||
self.attn_factor = attn_factor
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
||||
self.mscale = float(
|
||||
get_mscale(self.scaling_factor) * self.attn_factor
|
||||
) # Get n-d magnitude scaling corrected for interpolation
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
|
@ -788,13 +812,26 @@ try:
|
|||
)
|
||||
freqs = 1.0 / inv_freq_extrapolation
|
||||
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
|
||||
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
|
||||
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
low, high = find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
self.dim,
|
||||
self.base,
|
||||
self.max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = (
|
||||
1
|
||||
- linear_ramp_mask(low, high, self.dim // 2).float().to(device)
|
||||
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq = (
|
||||
inv_freq_interpolation * (1 - inv_freq_mask)
|
||||
+ inv_freq_extrapolation * inv_freq_mask
|
||||
)
|
||||
|
||||
self.inv_freq = inv_freq
|
||||
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
||||
|
||||
self.mscale = float(
|
||||
get_mscale(self.scaling_factor) * self.attn_factor
|
||||
) # Get n-d magnitude scaling corrected for interpolation
|
||||
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
|
|
|
@ -16,7 +16,7 @@ class Weights:
|
|||
dtype,
|
||||
process_group,
|
||||
aliases: Optional[Dict[str, List[str]]] = None,
|
||||
prefix: Optional[str] = None
|
||||
prefix: Optional[str] = None,
|
||||
):
|
||||
routing = {}
|
||||
for filename in filenames:
|
||||
|
@ -213,6 +213,7 @@ class Weights:
|
|||
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||
|
||||
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
else:
|
||||
|
|
|
@ -21,7 +21,7 @@ def main():
|
|||
block = []
|
||||
for line in lines:
|
||||
if line.startswith(" -") or line.startswith(" -"):
|
||||
rendered_block = '\n'.join(block)
|
||||
rendered_block = "\n".join(block)
|
||||
if header:
|
||||
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
|
||||
else:
|
||||
|
@ -36,7 +36,7 @@ def main():
|
|||
|
||||
block.append(line)
|
||||
|
||||
rendered_block = '\n'.join(block)
|
||||
rendered_block = "\n".join(block)
|
||||
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
|
||||
block = []
|
||||
|
||||
|
|
Loading…
Reference in New Issue