feat: support attention sinks
This commit is contained in:
parent
3c373dcc53
commit
cc36128cda
|
@ -333,6 +333,16 @@ struct Args {
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
rope_factor: Option<f32>,
|
rope_factor: Option<f32>,
|
||||||
|
|
||||||
|
/// Sliding Window will only be used by flash attention optimized models
|
||||||
|
/// Limit the Paged Attention context window size
|
||||||
|
#[clap(long, env)]
|
||||||
|
sliding_window: Option<u32>,
|
||||||
|
|
||||||
|
/// If `sliding_window` is set, always keep the first `attention_sinks` tokens in the context
|
||||||
|
/// See: [Efficient Streaming Language Models with Attention Sinks](https://arxiv.org/abs/2309.17453)
|
||||||
|
#[clap(long, env)]
|
||||||
|
attention_sinks: Option<u32>,
|
||||||
|
|
||||||
/// Outputs the logs in JSON format (useful for telemetry)
|
/// Outputs the logs in JSON format (useful for telemetry)
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
json_output: bool,
|
json_output: bool,
|
||||||
|
@ -390,6 +400,8 @@ fn shard_manager(
|
||||||
cuda_memory_fraction: f32,
|
cuda_memory_fraction: f32,
|
||||||
rope_scaling: Option<RopeScaling>,
|
rope_scaling: Option<RopeScaling>,
|
||||||
rope_factor: Option<f32>,
|
rope_factor: Option<f32>,
|
||||||
|
sliding_window: Option<u32>,
|
||||||
|
attention_sinks: Option<u32>,
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
status_sender: mpsc::Sender<ShardStatus>,
|
status_sender: mpsc::Sender<ShardStatus>,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
|
@ -495,6 +507,17 @@ fn shard_manager(
|
||||||
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Detect sliding window
|
||||||
|
// Sending as env instead of CLI args to not bloat everything
|
||||||
|
// those only can be used by flash attention models, so passing information around
|
||||||
|
// for all models will complexify code unnecessarily
|
||||||
|
if let Some(sliding_window) = sliding_window {
|
||||||
|
envs.push(("SLIDING_WINDOW".into(), sliding_window.to_string().into()));
|
||||||
|
}
|
||||||
|
if let Some(attention_sinks) = attention_sinks {
|
||||||
|
envs.push(("ATTENTION_SINKS".into(), attention_sinks.to_string().into()));
|
||||||
|
}
|
||||||
|
|
||||||
// If huggingface_hub_cache is some, pass it to the shard
|
// If huggingface_hub_cache is some, pass it to the shard
|
||||||
// Useful when running inside a docker container
|
// Useful when running inside a docker container
|
||||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||||
|
@ -891,6 +914,8 @@ fn spawn_shards(
|
||||||
let cuda_memory_fraction = args.cuda_memory_fraction;
|
let cuda_memory_fraction = args.cuda_memory_fraction;
|
||||||
let rope_scaling = args.rope_scaling;
|
let rope_scaling = args.rope_scaling;
|
||||||
let rope_factor = args.rope_factor;
|
let rope_factor = args.rope_factor;
|
||||||
|
let sliding_window = args.sliding_window;
|
||||||
|
let attention_sinks = args.attention_sinks;
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
shard_manager(
|
shard_manager(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -911,6 +936,8 @@ fn spawn_shards(
|
||||||
cuda_memory_fraction,
|
cuda_memory_fraction,
|
||||||
rope_scaling,
|
rope_scaling,
|
||||||
rope_factor,
|
rope_factor,
|
||||||
|
sliding_window,
|
||||||
|
attention_sinks,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
status_sender,
|
status_sender,
|
||||||
shutdown,
|
shutdown,
|
||||||
|
|
|
@ -297,7 +297,7 @@ def get_model(
|
||||||
raise ValueError("awq quantization is not supported for AutoModel")
|
raise ValueError("awq quantization is not supported for AutoModel")
|
||||||
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
||||||
raise ValueError("4bit quantization is not supported for AutoModel")
|
raise ValueError("4bit quantization is not supported for AutoModel")
|
||||||
elif (quantize == "eetq"):
|
elif quantize == "eetq":
|
||||||
raise ValueError("Eetq quantization is not supported for AutoModel")
|
raise ValueError("Eetq quantization is not supported for AutoModel")
|
||||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
return CausalLM(
|
return CausalLM(
|
||||||
|
|
|
@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM):
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(
|
weights = Weights(
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
|
filenames,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
process_group=self.process_group,
|
||||||
|
prefix="transformer",
|
||||||
)
|
)
|
||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
weights._set_gptq_params(model_id)
|
weights._set_gptq_params(model_id)
|
||||||
|
|
|
@ -15,12 +15,14 @@ class CacheManager:
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
|
attention_sinks: int,
|
||||||
repeat_slots: bool,
|
repeat_slots: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
self.block_size = BLOCK_SIZE
|
self.block_size = BLOCK_SIZE
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
|
self.attention_sinks = attention_sinks
|
||||||
self.repeat_slots = repeat_slots
|
self.repeat_slots = repeat_slots
|
||||||
|
|
||||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
@ -82,8 +84,23 @@ class CacheManager:
|
||||||
|
|
||||||
# Repeat slots in the case of context sliding window
|
# Repeat slots in the case of context sliding window
|
||||||
if needed_slots > len(all_slots) and self.repeat_slots:
|
if needed_slots > len(all_slots) and self.repeat_slots:
|
||||||
repeats = math.ceil(needed_slots / len(all_slots))
|
repeats = math.ceil(
|
||||||
all_slots = all_slots.repeat(repeats)
|
needed_slots / (len(all_slots) - self.attention_sinks)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.attention_sinks > 0:
|
||||||
|
# Remove attention sinks from the repeat to not override them
|
||||||
|
all_slots = torch.cat(
|
||||||
|
[
|
||||||
|
all_slots,
|
||||||
|
all_slots[self.attention_sinks :].repeat(repeats - 1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
all_slots = all_slots.repeat(repeats)
|
||||||
|
|
||||||
|
elif needed_slots > len(all_slots):
|
||||||
|
raise RuntimeError("Out of available slots. This is a bug")
|
||||||
|
|
||||||
allocated_slots = all_slots[:needed_slots]
|
allocated_slots = all_slots[:needed_slots]
|
||||||
|
|
||||||
|
@ -112,6 +129,7 @@ def set_cache_manager(
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
|
attention_sinks: int,
|
||||||
repeat_slots: bool,
|
repeat_slots: bool,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
@ -122,7 +140,14 @@ def set_cache_manager(
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
CACHE_MANAGER = CacheManager(
|
CACHE_MANAGER = CacheManager(
|
||||||
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
|
num_blocks,
|
||||||
|
num_layers,
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
attention_sinks,
|
||||||
|
repeat_slots,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
return CACHE_MANAGER
|
return CACHE_MANAGER
|
||||||
|
|
||||||
|
|
|
@ -254,6 +254,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
|
@ -269,8 +270,13 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
self.rotary_emb(query, cos, sin)
|
self.rotary_emb(query, cos, sin)
|
||||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
kv_to_cache = kv[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
kv_to_cache = kv
|
||||||
|
|
||||||
vllm_cache_ops.reshape_and_cache(
|
vllm_cache_ops.reshape_and_cache(
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
|
@ -376,6 +382,7 @@ class FlashLlamaLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
@ -390,6 +397,7 @@ class FlashLlamaLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
|
@ -442,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
@ -464,6 +473,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
@ -492,8 +502,19 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
sliding_window: int,
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||||
|
slots = slots[prefill_cache_indices]
|
||||||
|
elif sliding_window != -1:
|
||||||
|
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||||
|
# kernel requires the true values
|
||||||
|
max_s = min(sliding_window, max_s)
|
||||||
|
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||||
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
@ -503,6 +524,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
|
@ -201,9 +201,6 @@ class MistralAttention(torch.nn.Module):
|
||||||
weights,
|
weights,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_past = (
|
|
||||||
config.sliding_window if config.sliding_window is not None else 0
|
|
||||||
)
|
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.head_size = self.hidden_size // self.num_heads
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
@ -252,6 +249,7 @@ class MistralAttention(torch.nn.Module):
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
sliding_window,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
|
@ -290,7 +288,7 @@ class MistralAttention(torch.nn.Module):
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=sliding_window,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
|
@ -381,6 +379,7 @@ class MistralLayer(nn.Module):
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
sliding_window,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
@ -396,6 +395,7 @@ class MistralLayer(nn.Module):
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
sliding_window,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
|
@ -449,6 +449,7 @@ class MistralModel(torch.nn.Module):
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
sliding_window: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
@ -472,6 +473,7 @@ class MistralModel(torch.nn.Module):
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
sliding_window,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
@ -489,9 +491,6 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
||||||
prefix="lm_head",
|
prefix="lm_head",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
self.max_past = config.sliding_window
|
|
||||||
if self.max_past is None:
|
|
||||||
raise ValueError("max_past cannot be None")
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -504,16 +503,17 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
sliding_window: int,
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if prefill_cache_indices is not None:
|
if prefill_cache_indices is not None:
|
||||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||||
slots = slots[prefill_cache_indices]
|
slots = slots[prefill_cache_indices]
|
||||||
else:
|
elif sliding_window != -1:
|
||||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||||
# kernel requires the true values
|
# kernel requires the true values
|
||||||
max_s = min(self.max_past, max_s)
|
max_s = min(sliding_window, max_s)
|
||||||
input_lengths = torch.clamp(input_lengths, max=self.max_past)
|
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||||
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
@ -525,6 +525,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
|
sliding_window,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
|
@ -133,6 +133,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||||
|
@ -141,20 +142,28 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
self.rotary_emb(qkv[:, 0], cos, sin)
|
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||||
|
|
||||||
|
query, kv = qkv.split([1, 2], dim=1)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
kv_to_cache = kv[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
kv_to_cache = kv
|
||||||
|
|
||||||
vllm_cache_ops.reshape_and_cache(
|
vllm_cache_ops.reshape_and_cache(
|
||||||
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(qkv[:, 0])
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attention(
|
||||||
qkv[:, 0],
|
query,
|
||||||
qkv[:, 1],
|
kv[:, 0],
|
||||||
qkv[:, 2],
|
kv[:, 1],
|
||||||
attn_output,
|
attn_output,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
|
@ -166,7 +175,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
block_size = kv_cache[1].shape[3]
|
block_size = kv_cache[1].shape[3]
|
||||||
vllm_attention_ops.single_query_cached_kv_attention(
|
vllm_attention_ops.single_query_cached_kv_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
qkv[:, 0],
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
|
@ -245,6 +254,7 @@ class FlashNeoXLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
if self.use_parallel_residual:
|
if self.use_parallel_residual:
|
||||||
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
|
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
|
||||||
|
@ -259,6 +269,7 @@ class FlashNeoXLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
|
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
|
||||||
|
@ -283,6 +294,7 @@ class FlashNeoXLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
@ -337,6 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_in(input_ids)
|
hidden_states = self.embed_in(input_ids)
|
||||||
|
|
||||||
|
@ -359,6 +372,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
|
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
|
||||||
|
@ -385,8 +399,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
sliding_window: int,
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||||
|
slots = slots[prefill_cache_indices]
|
||||||
|
elif sliding_window != -1:
|
||||||
|
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||||
|
# kernel requires the true values
|
||||||
|
max_s = min(sliding_window, max_s)
|
||||||
|
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||||
|
|
||||||
hidden_states = self.gpt_neox(
|
hidden_states = self.gpt_neox(
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
@ -396,6 +421,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
|
@ -174,6 +174,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
|
||||||
|
@ -191,8 +192,13 @@ class FlashRWAttention(torch.nn.Module):
|
||||||
self.rotary_emb(query, cos, sin)
|
self.rotary_emb(query, cos, sin)
|
||||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
kv_to_cache = kv[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
kv_to_cache = kv
|
||||||
|
|
||||||
vllm_cache_ops.reshape_and_cache(
|
vllm_cache_ops.reshape_and_cache(
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
|
@ -294,6 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
|
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
|
||||||
|
@ -310,9 +317,14 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||||
self.rotary_emb(query, cos, sin)
|
self.rotary_emb(query, cos, sin)
|
||||||
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
|
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
|
||||||
|
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
kv_to_cache = kv[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
kv_to_cache = kv
|
||||||
|
|
||||||
vllm_cache_ops.reshape_and_cache(
|
vllm_cache_ops.reshape_and_cache(
|
||||||
kv[:, :, 0].contiguous(),
|
kv_to_cache[:, :, 0].contiguous(),
|
||||||
kv[:, :, 1].contiguous(),
|
kv_to_cache[:, :, 1].contiguous(),
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
slots,
|
slots,
|
||||||
|
@ -428,6 +440,7 @@ class FlashRWLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
if self.parallel_attn:
|
if self.parallel_attn:
|
||||||
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
@ -442,6 +455,7 @@ class FlashRWLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
mlp_output = self.mlp(ln_hidden_states)
|
mlp_output = self.mlp(ln_hidden_states)
|
||||||
|
@ -464,6 +478,7 @@ class FlashRWLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
@ -513,6 +528,7 @@ class FlashRWLargeLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
||||||
ln_mlp, _ = self.ln_mlp(residual)
|
ln_mlp, _ = self.ln_mlp(residual)
|
||||||
|
@ -528,6 +544,7 @@ class FlashRWLargeLayer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
# MLP.
|
# MLP.
|
||||||
|
@ -589,6 +606,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
prefill_cache_indices,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.word_embeddings(input_ids)
|
hidden_states = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
@ -611,6 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||||
|
@ -638,8 +657,19 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
sliding_window: int,
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||||
|
slots = slots[prefill_cache_indices]
|
||||||
|
elif sliding_window != -1:
|
||||||
|
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||||
|
# kernel requires the true values
|
||||||
|
max_s = min(sliding_window, max_s)
|
||||||
|
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||||
|
|
||||||
hidden_states = self.transformer(
|
hidden_states = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
@ -649,6 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
|
@ -246,6 +246,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
qkv = self.c_attn(hidden_states)
|
qkv = self.c_attn(hidden_states)
|
||||||
|
|
||||||
|
@ -258,8 +259,13 @@ class FlashMQAttention(torch.nn.Module):
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
key_value = key_value.view(-1, 2, 1, self.head_size)
|
key_value = key_value.view(-1, 2, 1, self.head_size)
|
||||||
|
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
kv_to_cache = key_value[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
kv_to_cache = key_value
|
||||||
|
|
||||||
vllm_cache_ops.reshape_and_cache(
|
vllm_cache_ops.reshape_and_cache(
|
||||||
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
|
@ -367,6 +373,7 @@ class Block(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, residual = self.ln_2(hidden_states, residual)
|
hidden_states, residual = self.ln_2(hidden_states, residual)
|
||||||
|
@ -420,6 +427,7 @@ class FlashSantacoderModel(nn.Module):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
prefill_cache_indices,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||||
|
|
||||||
|
@ -437,6 +445,7 @@ class FlashSantacoderModel(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||||
|
@ -462,8 +471,19 @@ class FlashSantacoderForCausalLM(nn.Module):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
sliding_window: int,
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||||
|
slots = slots[prefill_cache_indices]
|
||||||
|
elif sliding_window != -1:
|
||||||
|
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||||
|
# kernel requires the true values
|
||||||
|
max_s = min(sliding_window, max_s)
|
||||||
|
input_lengths = torch.clamp(input_lengths, max=sliding_window)
|
||||||
|
|
||||||
hidden_states = self.transformer(
|
hidden_states = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
@ -473,6 +493,7 @@ class FlashSantacoderForCausalLM(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
|
@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
|
||||||
image = image_url_or_urls
|
image = image_url_or_urls
|
||||||
|
|
||||||
if image.startswith("http://") or image.startswith("https://"):
|
if image.startswith("http://") or image.startswith("https://"):
|
||||||
response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5))
|
response = requests.get(
|
||||||
|
image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
content = response.content
|
content = response.content
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -9,7 +9,7 @@ import numpy as np
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from typing import Optional, Tuple, List, Type, Union, Dict
|
from typing import Optional, Tuple, List, Type, Dict
|
||||||
|
|
||||||
from text_generation_server.models import Model
|
from text_generation_server.models import Model
|
||||||
from text_generation_server.models.types import (
|
from text_generation_server.models.types import (
|
||||||
|
@ -24,6 +24,10 @@ from text_generation_server.models.cache_manager import (
|
||||||
set_cache_manager,
|
set_cache_manager,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.models.sliding_window import (
|
||||||
|
set_sliding_window_from_env,
|
||||||
|
get_sliding_window,
|
||||||
|
)
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||||
|
@ -47,6 +51,12 @@ class FlashCausalLMBatch(Batch):
|
||||||
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
|
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor]
|
cu_seqlen_prefill: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# Sliding window values
|
||||||
|
|
||||||
|
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
||||||
|
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor]
|
||||||
|
|
||||||
# Paged Attention values
|
# Paged Attention values
|
||||||
|
|
||||||
# Set when creating the batch
|
# Set when creating the batch
|
||||||
|
@ -109,6 +119,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "FlashCausalLMBatch":
|
) -> "FlashCausalLMBatch":
|
||||||
|
sliding_window = get_sliding_window()
|
||||||
|
|
||||||
batch_inputs = []
|
batch_inputs = []
|
||||||
max_truncation = 0
|
max_truncation = 0
|
||||||
for r in pb.requests:
|
for r in pb.requests:
|
||||||
|
@ -124,6 +136,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
needed_blocks_slots = []
|
needed_blocks_slots = []
|
||||||
start_slots = []
|
start_slots = []
|
||||||
slot_indices = []
|
slot_indices = []
|
||||||
|
prefill_cache_indices = []
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
prefix_offsets = []
|
prefix_offsets = []
|
||||||
|
@ -187,8 +200,15 @@ class FlashCausalLMBatch(Batch):
|
||||||
# Paged attention
|
# Paged attention
|
||||||
# Remove one as the first token des not have a past
|
# Remove one as the first token des not have a past
|
||||||
total_tokens = input_length + max_new_tokens - 1
|
total_tokens = input_length + max_new_tokens - 1
|
||||||
|
|
||||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
||||||
|
|
||||||
|
# If using sliding window
|
||||||
|
if sliding_window is not None:
|
||||||
|
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
||||||
|
needed_blocks = min(needed_blocks, sliding_window.blocks)
|
||||||
blocks += needed_blocks
|
blocks += needed_blocks
|
||||||
|
|
||||||
needed_blocks_slots.append((needed_blocks, total_tokens))
|
needed_blocks_slots.append((needed_blocks, total_tokens))
|
||||||
start_slots.append(cumulative_max_length)
|
start_slots.append(cumulative_max_length)
|
||||||
|
|
||||||
|
@ -199,6 +219,32 @@ class FlashCausalLMBatch(Batch):
|
||||||
)
|
)
|
||||||
slot_indices.append(request_slot_indices)
|
slot_indices.append(request_slot_indices)
|
||||||
|
|
||||||
|
# If using sliding window
|
||||||
|
if sliding_window is not None:
|
||||||
|
# Start of the sliding window cache
|
||||||
|
start_offset = max(
|
||||||
|
0,
|
||||||
|
input_length - sliding_window.size + sliding_window.attention_sinks,
|
||||||
|
)
|
||||||
|
|
||||||
|
if sliding_window.attention_sinks > 0 and start_offset > 0:
|
||||||
|
# Attention sinks indices
|
||||||
|
request_attention_sinks_cache_indices = torch.arange(
|
||||||
|
cumulative_length,
|
||||||
|
cumulative_length
|
||||||
|
+ min(sliding_window.attention_sinks, start_offset),
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
prefill_cache_indices.append(request_attention_sinks_cache_indices)
|
||||||
|
|
||||||
|
# Create tensor to slice into the kv tensor in prefill
|
||||||
|
request_prefill_cache_indices = torch.arange(
|
||||||
|
cumulative_length + start_offset,
|
||||||
|
cumulative_length + input_length,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||||
|
|
||||||
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||||
|
|
||||||
|
@ -252,12 +298,26 @@ class FlashCausalLMBatch(Batch):
|
||||||
position_ids = position_ids[0]
|
position_ids = position_ids[0]
|
||||||
slot_indices = slot_indices[0]
|
slot_indices = slot_indices[0]
|
||||||
|
|
||||||
|
if len(prefill_cache_indices) > 1:
|
||||||
|
prefill_cache_indices = (
|
||||||
|
torch.cat(prefill_cache_indices) if prefill_cache_indices else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefill_cache_indices = (
|
||||||
|
prefill_cache_indices[0] if prefill_cache_indices else None
|
||||||
|
)
|
||||||
|
|
||||||
cu_seqlen_prefill = torch.tensor(
|
cu_seqlen_prefill = torch.tensor(
|
||||||
cu_seqlen_prefill, device=device, dtype=torch.int32
|
cu_seqlen_prefill, device=device, dtype=torch.int32
|
||||||
)
|
)
|
||||||
|
|
||||||
position_ids = position_ids.to(device)
|
position_ids = position_ids.to(device)
|
||||||
slot_indices = slot_indices.to(device)
|
slot_indices = slot_indices.to(device)
|
||||||
|
prefill_cache_indices = (
|
||||||
|
prefill_cache_indices.to(device)
|
||||||
|
if prefill_cache_indices is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
input_lengths_tensor = torch.tensor(
|
input_lengths_tensor = torch.tensor(
|
||||||
input_lengths, dtype=torch.int32, device=device
|
input_lengths, dtype=torch.int32, device=device
|
||||||
|
@ -309,6 +369,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
max_blocks=max_blocks,
|
max_blocks=max_blocks,
|
||||||
|
prefill_cache_indices=prefill_cache_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tracer.start_as_current_span("filter")
|
@tracer.start_as_current_span("filter")
|
||||||
|
@ -425,7 +486,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
# Move to GPU now that we have the whole tensor
|
# Move to GPU now that we have the whole tensor
|
||||||
slot_indices = slot_indices.to(device)
|
slot_indices = slot_indices.to(device)
|
||||||
|
|
||||||
return type(self)(
|
return FlashCausalLMBatch(
|
||||||
batch_id=self.batch_id,
|
batch_id=self.batch_id,
|
||||||
requests=requests,
|
requests=requests,
|
||||||
requests_idx_mapping=requests_idx_mapping,
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
@ -454,6 +515,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
max_blocks=max_blocks,
|
max_blocks=max_blocks,
|
||||||
|
prefill_cache_indices=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -611,6 +673,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
max_blocks=max_blocks,
|
max_blocks=max_blocks,
|
||||||
|
prefill_cache_indices=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
@ -636,11 +699,11 @@ class FlashCausalLM(Model):
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
sliding_window: Optional[int] = None,
|
|
||||||
):
|
):
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.num_kv_heads = num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
|
set_sliding_window_from_env()
|
||||||
|
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -650,7 +713,6 @@ class FlashCausalLM(Model):
|
||||||
device=device,
|
device=device,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
sliding_window=sliding_window,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -658,6 +720,8 @@ class FlashCausalLM(Model):
|
||||||
return FlashCausalLMBatch
|
return FlashCausalLMBatch
|
||||||
|
|
||||||
def warmup(self, batch: FlashCausalLMBatch):
|
def warmup(self, batch: FlashCausalLMBatch):
|
||||||
|
sliding_window = get_sliding_window()
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
try:
|
try:
|
||||||
cache_manager = set_cache_manager(
|
cache_manager = set_cache_manager(
|
||||||
|
@ -665,7 +729,8 @@ class FlashCausalLM(Model):
|
||||||
self.num_layers,
|
self.num_layers,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
self.sliding_window is not None,
|
sliding_window.attention_sinks if sliding_window is not None else 0,
|
||||||
|
True if sliding_window is not None else False,
|
||||||
self.dtype,
|
self.dtype,
|
||||||
self.device,
|
self.device,
|
||||||
)
|
)
|
||||||
|
@ -705,7 +770,8 @@ class FlashCausalLM(Model):
|
||||||
self.num_layers,
|
self.num_layers,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
self.sliding_window is not None,
|
sliding_window.attention_sinks if sliding_window is not None else 0,
|
||||||
|
True if sliding_window is not None else False,
|
||||||
self.dtype,
|
self.dtype,
|
||||||
self.device,
|
self.device,
|
||||||
)
|
)
|
||||||
|
@ -713,8 +779,10 @@ class FlashCausalLM(Model):
|
||||||
return int(num_blocks * BLOCK_SIZE)
|
return int(num_blocks * BLOCK_SIZE)
|
||||||
|
|
||||||
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
sliding_window = get_sliding_window()
|
||||||
|
|
||||||
# Model Forward
|
# Model Forward
|
||||||
return self.model.forward(
|
logits = self.model.forward(
|
||||||
input_ids=batch.input_ids,
|
input_ids=batch.input_ids,
|
||||||
position_ids=batch.position_ids,
|
position_ids=batch.position_ids,
|
||||||
cu_seqlen_prefill=batch.cu_seqlen_prefill,
|
cu_seqlen_prefill=batch.cu_seqlen_prefill,
|
||||||
|
@ -723,8 +791,13 @@ class FlashCausalLM(Model):
|
||||||
slots=batch.slots[batch.slot_indices],
|
slots=batch.slots[batch.slot_indices],
|
||||||
input_lengths=batch.input_lengths_tensor,
|
input_lengths=batch.input_lengths_tensor,
|
||||||
max_s=batch.max_seqlen,
|
max_s=batch.max_seqlen,
|
||||||
|
prefill_cache_indices=batch.prefill_cache_indices,
|
||||||
|
sliding_window=sliding_window.size if sliding_window is not None else -1,
|
||||||
lm_head_indices=batch.prefill_head_indices,
|
lm_head_indices=batch.prefill_head_indices,
|
||||||
)
|
)
|
||||||
|
if batch.prefill_cache_indices is not None:
|
||||||
|
batch.prefill_cache_indices = None
|
||||||
|
return logits
|
||||||
|
|
||||||
@tracer.start_as_current_span("generate_token")
|
@tracer.start_as_current_span("generate_token")
|
||||||
def generate_token(
|
def generate_token(
|
||||||
|
|
|
@ -1,21 +1,14 @@
|
||||||
import math
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
from transformers.models.llama import LlamaTokenizerFast
|
from transformers.models.llama import LlamaTokenizerFast
|
||||||
from typing import Optional, Tuple, Type
|
from typing import Optional
|
||||||
|
|
||||||
from text_generation_server.pb import generate_pb2
|
|
||||||
from text_generation_server.models import FlashCausalLM
|
from text_generation_server.models import FlashCausalLM
|
||||||
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
|
from text_generation_server.models.sliding_window import (
|
||||||
from text_generation_server.models.cache_manager import (
|
set_sliding_window,
|
||||||
get_cache_manager,
|
get_sliding_window,
|
||||||
set_cache_manager,
|
|
||||||
)
|
)
|
||||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||||
FlashMistralForCausalLM,
|
FlashMistralForCausalLM,
|
||||||
|
@ -25,255 +18,10 @@ from text_generation_server.utils import (
|
||||||
initialize_torch_distributed,
|
initialize_torch_distributed,
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
HeterogeneousNextTokenChooser,
|
|
||||||
StoppingCriteria,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
# Will be set in init
|
|
||||||
SLIDING_WINDOW: Optional[int] = None
|
|
||||||
SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
# Adds windowing logic to FlashCausalLMBatch
|
|
||||||
@dataclass
|
|
||||||
class FlashMistralBatch(FlashCausalLMBatch):
|
|
||||||
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
|
||||||
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
|
||||||
prefill_cache_indices: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pb(
|
|
||||||
cls,
|
|
||||||
pb: generate_pb2.Batch,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
) -> "FlashCausalLMBatch":
|
|
||||||
global SLIDING_WINDOW
|
|
||||||
global SLIDING_WINDOW_BLOCKS
|
|
||||||
|
|
||||||
batch_inputs = []
|
|
||||||
max_truncation = 0
|
|
||||||
for r in pb.requests:
|
|
||||||
batch_inputs.append(r.inputs)
|
|
||||||
max_truncation = max(max_truncation, r.truncate)
|
|
||||||
|
|
||||||
batch_tokenized_inputs = tokenizer(
|
|
||||||
batch_inputs, truncation=True, max_length=max_truncation
|
|
||||||
)["input_ids"]
|
|
||||||
|
|
||||||
position_ids = []
|
|
||||||
cu_seqlen_prefill = [0]
|
|
||||||
needed_blocks_slots = []
|
|
||||||
start_slots = []
|
|
||||||
slot_indices = []
|
|
||||||
prefill_cache_indices = []
|
|
||||||
|
|
||||||
input_lengths = []
|
|
||||||
prefix_offsets = []
|
|
||||||
read_offsets = []
|
|
||||||
all_input_ids = []
|
|
||||||
requests_idx_mapping = {}
|
|
||||||
|
|
||||||
all_prefill_logprobs = True
|
|
||||||
no_prefill_logprobs = True
|
|
||||||
prefill_head_indices = []
|
|
||||||
prefill_next_token_indices = []
|
|
||||||
prefill_cu_outlens = [0]
|
|
||||||
|
|
||||||
next_token_chooser_parameters = []
|
|
||||||
stopping_criterias = []
|
|
||||||
top_n_tokens = []
|
|
||||||
|
|
||||||
# Cumulative length
|
|
||||||
cumulative_length = 0
|
|
||||||
cumulative_max_length = 0
|
|
||||||
prefill_out_cumulative_length = 0
|
|
||||||
|
|
||||||
blocks = 0
|
|
||||||
max_seqlen = 0
|
|
||||||
max_length = 0
|
|
||||||
max_blocks = 0
|
|
||||||
|
|
||||||
# Parse batch
|
|
||||||
for i, (r, tokenized_input) in enumerate(
|
|
||||||
zip(pb.requests, batch_tokenized_inputs)
|
|
||||||
):
|
|
||||||
# request id -> idx in list mapping
|
|
||||||
requests_idx_mapping[r.id] = i
|
|
||||||
|
|
||||||
tokenized_input = tokenized_input[-r.truncate :]
|
|
||||||
|
|
||||||
input_length = len(tokenized_input)
|
|
||||||
input_lengths.append(input_length)
|
|
||||||
|
|
||||||
prefix_offsets.append(input_length - 5)
|
|
||||||
read_offsets.append(input_length)
|
|
||||||
|
|
||||||
all_input_ids.append(tokenized_input)
|
|
||||||
|
|
||||||
# Position ids
|
|
||||||
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
|
||||||
position_ids.append(request_position_ids)
|
|
||||||
|
|
||||||
# Add cumulative lengths of all previous inputs
|
|
||||||
cu_seqlen_prefill.append(cumulative_length + input_length)
|
|
||||||
|
|
||||||
next_token_chooser_parameters.append(r.parameters)
|
|
||||||
|
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
|
||||||
r.stopping_parameters, tokenizer
|
|
||||||
)
|
|
||||||
max_new_tokens = stopping_criteria.max_new_tokens
|
|
||||||
stopping_criterias.append(stopping_criteria)
|
|
||||||
top_n_tokens.append(r.top_n_tokens)
|
|
||||||
|
|
||||||
# Paged attention
|
|
||||||
# Remove one as the first token des not have a past
|
|
||||||
total_tokens = input_length + max_new_tokens - 1
|
|
||||||
|
|
||||||
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
|
||||||
needed_blocks = min(
|
|
||||||
math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
|
|
||||||
)
|
|
||||||
blocks += needed_blocks
|
|
||||||
|
|
||||||
needed_blocks_slots.append((needed_blocks, total_tokens))
|
|
||||||
start_slots.append(cumulative_max_length)
|
|
||||||
|
|
||||||
request_slot_indices = torch.arange(
|
|
||||||
cumulative_max_length,
|
|
||||||
cumulative_max_length + input_length,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
slot_indices.append(request_slot_indices)
|
|
||||||
|
|
||||||
# Create tensor to slice into the kv tensor in prefill
|
|
||||||
request_prefill_cache_indices = torch.arange(
|
|
||||||
cumulative_length + max(0, input_length - SLIDING_WINDOW),
|
|
||||||
cumulative_length + input_length,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
prefill_cache_indices.append(request_prefill_cache_indices)
|
|
||||||
|
|
||||||
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
|
||||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
|
||||||
|
|
||||||
if r.prefill_logprobs:
|
|
||||||
prefill_head_indices.append(request_position_ids + cumulative_length)
|
|
||||||
prefill_next_token_indices.append(
|
|
||||||
prefill_out_cumulative_length + input_length - 1
|
|
||||||
)
|
|
||||||
-                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
-                prefill_out_cumulative_length += input_length
-            else:
-                prefill_head_indices.append(
-                    torch.tensor(
-                        [cumulative_length + input_length - 1], dtype=torch.int32
-                    )
-                )
-                prefill_next_token_indices.append(prefill_out_cumulative_length)
-                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
-                prefill_out_cumulative_length += 1
-
-            # Update
-            cumulative_length += input_length
-            cumulative_max_length += total_tokens
-            max_seqlen = max(max_seqlen, input_length)
-            max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
-
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
-        )
-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
-        # Padded all_input_ids_tensor
-        all_input_ids_tensor = np.zeros(
-            (len(all_input_ids), max_length), dtype=np.int64
-        )
-        for i, input_ids in enumerate(all_input_ids):
-            all_input_ids_tensor[i, : len(input_ids)] = input_ids
-
-        # Create tensors on device
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
-
-        if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
-            position_ids = torch.cat(position_ids)
-            slot_indices = torch.cat(slot_indices)
-            prefill_cache_indices = torch.cat(prefill_cache_indices)
-        else:
-            input_ids = all_input_ids[0]
-            position_ids = position_ids[0]
-            slot_indices = slot_indices[0]
-            prefill_cache_indices = prefill_cache_indices[0]
-
-        cu_seqlen_prefill = torch.tensor(
-            cu_seqlen_prefill, device=device, dtype=torch.int32
-        )
-
-        position_ids = position_ids.to(device)
-        slot_indices = slot_indices.to(device)
-        prefill_cache_indices = prefill_cache_indices.to(device)
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        input_lengths_tensor = torch.tensor(
-            input_lengths, dtype=torch.int32, device=device
-        )
-
-        if all_prefill_logprobs:
-            prefill_head_indices = None
-            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
-        elif no_prefill_logprobs:
-            prefill_head_indices = cu_seqlen_prefill[1:] - 1
-            prefill_next_token_indices = None
-        else:
-            prefill_head_indices = torch.tensor(
-                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
-            )
-            prefill_next_token_indices = torch.tensor(
-                prefill_next_token_indices, dtype=torch.int64, device=device
-            )
-        top_n_tokens_tensor = torch.tensor(
-            top_n_tokens, device=device, dtype=torch.int64
-        )
-
-        return cls(
-            batch_id=pb.id,
-            requests=pb.requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            start_slots=start_slots,
-            slot_indices=slot_indices,
-            needed_blocks_slots=needed_blocks_slots,
-            block_tables=None,
-            block_tables_tensor=None,
-            slots=None,
-            max_seqlen=max_seqlen,
-            prefill_head_indices=prefill_head_indices,
-            prefill_next_token_indices=prefill_next_token_indices,
-            prefill_cu_outlens=prefill_cu_outlens,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
-            prefix_offsets=prefix_offsets,
-            read_offsets=read_offsets,
-            all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
-            next_token_chooser=next_token_chooser,
-            stopping_criterias=stopping_criterias,
-            top_n_tokens=top_n_tokens,
-            top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
-            max_blocks=max_blocks,
-            prefill_cache_indices=prefill_cache_indices,
-        )


class FlashMistral(FlashCausalLM):
    def __init__(
@@ -284,9 +32,6 @@ class FlashMistral(FlashCausalLM):
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
-
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -308,8 +53,7 @@ class FlashMistral(FlashCausalLM):
        config.quantize = quantize

        # Set context windows
-        SLIDING_WINDOW = config.sliding_window
-        SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
+        set_sliding_window(config.sliding_window, 0)

        torch.distributed.barrier(group=self.process_group)

@@ -331,27 +75,4 @@ class FlashMistral(FlashCausalLM):
            device=device,
            rank=rank,
            world_size=world_size,
-            sliding_window=config.sliding_window,
        )
-
-    @property
-    def batch_type(self) -> Type[FlashMistralBatch]:
-        return FlashMistralBatch
-
-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Model Forward
-        logits = self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
-            prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=batch.prefill_head_indices,
-        )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
-        return logits
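
For orientation, here is a minimal sketch of the pattern the hunks above move to: the backend registers its context window once at load time and any other component reads it back through the shared accessor, instead of a `sliding_window` value being threaded through every constructor. `ToyConfig` is a hypothetical stand-in for a model config; `set_sliding_window` and `get_sliding_window` are the helpers from the new `text_generation_server.models.sliding_window` module added later in this diff.

```python
from text_generation_server.models.sliding_window import (
    get_sliding_window,
    set_sliding_window,
)


class ToyConfig:
    # Hypothetical stand-in for a transformers config that defines a window.
    sliding_window = 4096


# At model-load time: register the window once (no attention sinks here, hence 0).
set_sliding_window(ToyConfig.sliding_window, 0)

# Anywhere else in the process: query the registered window.
window = get_sliding_window()
if window is not None:
    print(window.size, window.blocks, window.attention_sinks)
```
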
@@ -3,10 +3,11 @@ import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, Generation
from text_generation_server.pb.generate_pb2 import InfoResponse
+from text_generation_server.models.sliding_window import get_sliding_window

B = TypeVar("B", bound=Batch)

@@ -21,7 +22,6 @@ class Model(ABC):
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
-        sliding_window: Optional[int] = None,
    ):
        self.model = model.eval()
        self.tokenizer = tokenizer
@@ -31,7 +31,6 @@ class Model(ABC):
        self.device = device
        self.rank = rank
        self.world_size = world_size
-        self.sliding_window = sliding_window

        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
@@ -42,14 +41,15 @@ class Model(ABC):

    @property
    def info(self) -> InfoResponse:
-        if self.requires_padding and self.sliding_window is not None:
+        sliding_window = get_sliding_window()
+        if self.requires_padding and sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
-            window_size=self.sliding_window,
+            window_size=sliding_window.size if sliding_window is not None else None,
        )

    @property
@@ -0,0 +1,41 @@
+import os
+import math
+
+from typing import Optional
+
+from text_generation_server.models.cache_manager import BLOCK_SIZE
+
+SLIDING_WINDOW: Optional["SlidingWindow"] = None
+
+
+class SlidingWindow:
+    def __init__(self, size: int, attention_sinks: int):
+        self.size = size
+        self.blocks = math.ceil(size / BLOCK_SIZE)
+        self.attention_sinks = attention_sinks
+
+    @classmethod
+    def from_env(cls) -> Optional["SlidingWindow"]:
+        sliding_window_env = os.getenv("SLIDING_WINDOW", None)
+        if sliding_window_env is not None:
+            return cls(int(sliding_window_env), int(os.getenv("ATTENTION_SINKS", 0)))
+        return None
+
+
+def set_sliding_window(size: int, attention_sinks: int) -> SlidingWindow:
+    global SLIDING_WINDOW
+    SLIDING_WINDOW = SlidingWindow(size, attention_sinks)
+    return SLIDING_WINDOW
+
+
+def set_sliding_window_from_env() -> Optional[SlidingWindow]:
+    global SLIDING_WINDOW
+    env_sliding_window = SlidingWindow.from_env()
+    if env_sliding_window is not None:
+        SLIDING_WINDOW = env_sliding_window
+    return SLIDING_WINDOW
+
+
+def get_sliding_window() -> Optional[SlidingWindow]:
+    global SLIDING_WINDOW
+    return SLIDING_WINDOW
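
A short usage sketch of the new module, assuming it behaves exactly as written above. The `SLIDING_WINDOW` and `ATTENTION_SINKS` environment variables are the ones `SlidingWindow.from_env` reads; the `kept_positions` helper at the end is purely illustrative (it is not code from this repository) of what an attention-sink window means: the first `attention_sinks` token positions stay cached alongside the most recent `size` positions.

```python
import os

from text_generation_server.models.sliding_window import (
    get_sliding_window,
    set_sliding_window_from_env,
)

# Shard-side configuration: the env vars read by SlidingWindow.from_env().
os.environ["SLIDING_WINDOW"] = "4096"
os.environ["ATTENTION_SINKS"] = "4"

window = set_sliding_window_from_env()
assert window is get_sliding_window()
print(window.size, window.blocks, window.attention_sinks)  # blocks = ceil(4096 / BLOCK_SIZE)


def kept_positions(seq_len: int, size: int, attention_sinks: int) -> list:
    # Illustration only: token positions a sink-aware sliding window retains.
    sinks = set(range(min(attention_sinks, seq_len)))
    recent = set(range(max(seq_len - size, 0), seq_len))
    return sorted(sinks | recent)


# With a window of 8 and 2 sinks, positions 0 and 1 survive even after they
# fall out of the most recent 8 tokens.
print(kept_positions(seq_len=12, size=8, attention_sinks=2))
# -> [0, 1, 4, 5, 6, 7, 8, 9, 10, 11]
```
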
@@ -604,15 +604,16 @@ try:
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
-                        max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+                        max_position_embeddings=rope_scaling[
+                            "original_max_position_embeddings"
+                        ],
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
-                        beta_slow=1
-
+                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
@@ -645,15 +646,16 @@ try:
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
-                        max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+                        max_position_embeddings=rope_scaling[
+                            "original_max_position_embeddings"
+                        ],
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
-                        beta_slow=1
-
+                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
@@ -734,19 +736,27 @@ try:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)


    # Inverse dim formula to find dim based on number of rotations
    import math

-    def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
-        return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
+    def find_correction_dim(
+        num_rotations, dim, base=10000, max_position_embeddings=2048
+    ):
+        return (
+            dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
+        ) / (2 * math.log(base))

    # Find dim range bounds based on rotations
-    def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
-        low = math.floor(find_correction_dim(
-            low_rot, dim, base, max_position_embeddings))
-        high = math.ceil(find_correction_dim(
-            high_rot, dim, base, max_position_embeddings))
-        return max(low, 0), min(high, dim-1)  # Clamp values just in case
+    def find_correction_range(
+        low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
+    ):
+        low = math.floor(
+            find_correction_dim(low_rot, dim, base, max_position_embeddings)
+        )
+        high = math.ceil(
+            find_correction_dim(high_rot, dim, base, max_position_embeddings)
+        )
+        return max(low, 0), min(high, dim - 1)  # Clamp values just in case

    def linear_ramp_mask(min, max, dim):
        if min == max:
@@ -762,7 +772,19 @@ try:
        return 0.1 * math.log(scale) + 1.0

    class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
-        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor, attn_factor, beta_fast, beta_slow):
+        def __init__(
+            self,
+            dim,
+            max_position_embeddings,
+            base,
+            device,
+            scaling_factor,
+            *,
+            extrapolation_factor,
+            attn_factor,
+            beta_fast,
+            beta_slow,
+        ):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
@@ -772,7 +794,9 @@ try:
            self.attn_factor = attn_factor
            self.beta_fast = beta_fast
            self.beta_slow = beta_slow
-            self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)  # Get n-d magnitude scaling corrected for interpolation
+            self.mscale = float(
+                get_mscale(self.scaling_factor) * self.attn_factor
+            )  # Get n-d magnitude scaling corrected for interpolation

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            # Reset the tables if the sequence length has changed,
@@ -788,13 +812,26 @@ try:
                )
                freqs = 1.0 / inv_freq_extrapolation
                inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
-                low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
-                inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
-                inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+                low, high = find_correction_range(
+                    self.beta_fast,
+                    self.beta_slow,
+                    self.dim,
+                    self.base,
+                    self.max_position_embeddings,
+                )
+                inv_freq_mask = (
+                    1
+                    - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
+                ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
+                inv_freq = (
+                    inv_freq_interpolation * (1 - inv_freq_mask)
+                    + inv_freq_extrapolation * inv_freq_mask
+                )

                self.inv_freq = inv_freq
-                self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)  # Get n-d magnitude scaling corrected for interpolation
+                self.mscale = float(
+                    get_mscale(self.scaling_factor) * self.attn_factor
+                )  # Get n-d magnitude scaling corrected for interpolation

                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
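
Since the reformatted YaRN helpers above are pure math, a tiny standalone check makes their behaviour concrete. The two functions below simply restate `find_correction_dim` and `find_correction_range` as they appear in this diff so the snippet runs on its own; the numbers are an arbitrary example (head dimension 128 with the `beta_fast=32`, `beta_slow=1` defaults used above).

```python
import math


def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # Inverse dim formula: the rotary dimension that completes `num_rotations`
    # full rotations over `max_position_embeddings` positions.
    return (
        dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    ) / (2 * math.log(base))


def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


low, high = find_correction_range(32, 1, 128, base=10000, max_position_embeddings=2048)
print(low, high)  # 16 41: the band over which YaRN blends extrapolated and interpolated frequencies
```
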
@@ -16,7 +16,7 @@ class Weights:
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
-        prefix: Optional[str] = None
+        prefix: Optional[str] = None,
    ):
        routing = {}
        for filename in filenames:
@@ -213,7 +213,8 @@ class Weights:

            bits, groupsize = self._get_gptq_params()
            from text_generation_server.utils.layers import HAS_EXLLAMA
-            use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq"
+
+            use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -21,14 +21,14 @@ def main():
    block = []
    for line in lines:
        if line.startswith("  -") or line.startswith("      -"):
-            rendered_block = '\n'.join(block)
+            rendered_block = "\n".join(block)
            if header:
                final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
            else:
                final_doc += f"```shell\n{rendered_block}\n```\n"
            block = []
            tokens = line.split("<")
-            if len(tokens)>1:
+            if len(tokens) > 1:
                header = tokens[-1][:-1]
            else:
                header = line.split("--")[-1]
@@ -36,7 +36,7 @@ def main():

        block.append(line)

-    rendered_block = '\n'.join(block)
+    rendered_block = "\n".join(block)
    final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
    block = []

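
The `update_doc` changes above are formatting-only, but the header extraction is easiest to see on a concrete input. A small sketch of just that fragment, with invented sample lines in the usual clap `--help` shape (it mirrors only the branch shown above, not the whole script):

```python
def extract_header(line: str) -> str:
    # Prefer the <PLACEHOLDER> name, otherwise fall back to the long flag name.
    tokens = line.split("<")
    if len(tokens) > 1:
        return tokens[-1][:-1]
    return line.split("--")[-1]


print(extract_header("      --sliding-window <SLIDING_WINDOW>"))  # SLIDING_WINDOW
print(extract_header("      --json-output"))  # json-output
```
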