parent e4b26aa10b
commit 31e2253ae7
@@ -1,9 +1,9 @@
-flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f
+flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
 
 flash-attention:
 # Clone flash attention
 pip install packaging
-git clone https://github.com/OlivierDehaene/flash-attention.git
+git clone https://github.com/HazyResearch/flash-attention.git
 
 build-flash-attention: flash-attention
 cd flash-attention && git fetch && git checkout $(flash_att_commit)
File diff suppressed because it is too large
@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "0.13.3"
 huggingface-hub = "^0.14.1"
-transformers = "^4.29.2"
+transformers = "4.29.2"
 einops = "^0.6.1"
 
 [tool.poetry.extras]
@@ -1,22 +1,23 @@
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
 bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
 certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
 charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
 click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
-colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
 filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
 fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
 googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
 grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
+grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
+grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
+grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
 hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
 huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
 idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
-numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0"
+numpy==1.25.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -27,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
 packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
-protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0"
+protobuf==4.23.3 ; python_version >= "3.9" and python_version < "4.0"
 pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
 regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
 safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
-setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
+setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
 tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
-transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0"
+transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
-typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0"
+typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
 urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -135,8 +135,7 @@ class FlashLlamaAttention(torch.nn.Module):
 hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -158,17 +157,15 @@ class FlashLlamaAttention(torch.nn.Module):
 attn_output = torch.empty_like(qkv[:, 0])
 
 # Prefill
-if start_seq_prefill is not None:
+if cu_seqlen_prefill is not None:
 # flash attention
 flash_attn_cuda.fwd(
 qkv[:, 0],
 qkv[:, 1],
 qkv[:, 2],
 attn_output,
-start_seq_prefill,
-end_seq_prefill,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
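Note: the attention kernel called above takes cumulative sequence-length tensors (cu_seqlens for queries and for keys) rather than separate start/end offsets, which is why a single cu_seqlen_prefill tensor now stands in for the four start/end arguments. A minimal sketch of the equivalence, with made-up lengths (not code from this repository):

import torch

# Prompt lengths for a batch of three sequences flattened into one dimension.
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)

# Old layout: two tensors of length b.
start_seq_prefill = torch.tensor([0, 3, 8], dtype=torch.int32)
end_seq_prefill = torch.tensor([3, 8, 10], dtype=torch.int32)

# New layout: one tensor of length b + 1 holding cumulative lengths.
cu_seqlen_prefill = torch.zeros(len(input_lengths) + 1, dtype=torch.int32)
cu_seqlen_prefill[1:] = torch.cumsum(input_lengths, dim=0)
# cu_seqlen_prefill is now tensor([0, 3, 8, 10], dtype=torch.int32)

# The old tensors are just two overlapping views of the cumulative one.
assert torch.equal(cu_seqlen_prefill[:-1], start_seq_prefill)
assert torch.equal(cu_seqlen_prefill[1:], end_seq_prefill)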
@@ -261,8 +258,7 @@ class FlashLlamaLayer(nn.Module):
 residual,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -276,8 +272,7 @@ class FlashLlamaLayer(nn.Module):
 normed_hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -329,8 +324,7 @@ class FlashLlamaModel(torch.nn.Module):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
@@ -352,8 +346,7 @@ class FlashLlamaModel(torch.nn.Module):
 residual,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache[i],
 block_tables,
 slots,
@@ -381,8 +374,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
@@ -393,8 +385,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
 hidden_states = self.model(
 input_ids,
 position_ids,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -123,8 +123,7 @@ class FlashNeoxAttention(torch.nn.Module):
 hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -146,17 +145,15 @@ class FlashNeoxAttention(torch.nn.Module):
 attn_output = torch.empty_like(qkv[:, 0])
 
 # Prefill
-if start_seq_prefill is not None:
+if cu_seqlen_prefill is not None:
 # flash attention
 flash_attn_cuda.fwd(
 qkv[:, 0],
 qkv[:, 1],
 qkv[:, 2],
 attn_output,
-start_seq_prefill,
-end_seq_prefill,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
@@ -246,8 +243,7 @@ class FlashNeoXLayer(nn.Module):
 residual,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -261,8 +257,7 @@ class FlashNeoXLayer(nn.Module):
 ln1_hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -286,8 +281,7 @@ class FlashNeoXLayer(nn.Module):
 hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -341,8 +335,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
@@ -364,8 +357,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 residual,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache[i],
 block_tables,
 slots,
@@ -391,8 +383,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
@@ -403,8 +394,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
 hidden_states = self.gpt_neox(
 input_ids,
 position_ids,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -144,8 +144,7 @@ class FlashRWAttention(torch.nn.Module):
 hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -176,7 +175,7 @@ class FlashRWAttention(torch.nn.Module):
 attn_output = torch.empty_like(query)
 
 # Prefill
-if start_seq_prefill is not None:
+if cu_seqlen_prefill is not None:
 if self.num_heads_kv == 1:
 # Expand to query shape
 kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@@ -187,10 +186,8 @@ class FlashRWAttention(torch.nn.Module):
 torch.select(kv, dim=1, index=0),
 torch.select(kv, dim=1, index=1),
 attn_output,
-start_seq_prefill,
-end_seq_prefill,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
@@ -276,8 +273,7 @@ class FlashRWLargeAttention(torch.nn.Module):
 hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -311,7 +307,7 @@ class FlashRWLargeAttention(torch.nn.Module):
 attn_output = torch.empty_like(query)
 
 # Prefill
-if start_seq_prefill is not None:
+if cu_seqlen_prefill is not None:
 # Expand to query shape
 kv = (
 kv.unsqueeze(2)
@@ -325,10 +321,8 @@ class FlashRWLargeAttention(torch.nn.Module):
 torch.select(kv, dim=2, index=0),
 torch.select(kv, dim=2, index=1),
 attn_output,
-start_seq_prefill,
-end_seq_prefill,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
@@ -428,8 +422,7 @@ class FlashRWLayer(nn.Module):
 residual,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -443,8 +436,7 @@ class FlashRWLayer(nn.Module):
 ln_hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -466,8 +458,7 @@ class FlashRWLayer(nn.Module):
 hidden_states,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -516,8 +507,7 @@ class FlashRWLargeLayer(nn.Module):
 residual,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -532,8 +522,7 @@ class FlashRWLargeLayer(nn.Module):
 ln_attn,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -597,8 +586,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
@@ -620,8 +608,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
 residual,
 cos,
 sin,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache[i],
 block_tables,
 slots,
@@ -648,8 +635,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
@@ -660,8 +646,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
 hidden_states = self.transformer(
 input_ids,
 position_ids,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -232,8 +232,7 @@ class FlashMQAttention(torch.nn.Module):
 def forward(
 self,
 hidden_states,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -259,7 +258,7 @@ class FlashMQAttention(torch.nn.Module):
 attn_output = torch.empty_like(query)
 
 # Prefill
-if start_seq_prefill is not None:
+if cu_seqlen_prefill is not None:
 # Expand from 1 to num_heads
 key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
 
@@ -269,10 +268,8 @@ class FlashMQAttention(torch.nn.Module):
 torch.select(key_value, dim=1, index=0),
 torch.select(key_value, dim=1, index=1),
 attn_output,
-start_seq_prefill,
-end_seq_prefill,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
+cu_seqlen_prefill,
 max_s,
 max_s,
 0.0,
@@ -357,8 +354,7 @@ class Block(nn.Module):
 self,
 hidden_states,
 residual,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -369,8 +365,7 @@ class Block(nn.Module):
 
 hidden_states = self.attn(
 hidden_states,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -423,8 +418,7 @@ class FlashSantacoderModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
@@ -441,8 +435,7 @@ class FlashSantacoderModel(nn.Module):
 hidden_states, residual = layer(
 hidden_states,
 residual,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache[i],
 block_tables,
 slots,
@@ -467,8 +460,7 @@ class FlashSantacoderForCausalLM(nn.Module):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
@@ -479,8 +471,7 @@ class FlashSantacoderForCausalLM(nn.Module):
 hidden_states = self.transformer(
 input_ids,
 position_ids,
-start_seq_prefill,
-end_seq_prefill,
+cu_seqlen_prefill,
 kv_cache,
 block_tables,
 slots,
@@ -121,10 +121,10 @@ class FlashCausalLMBatch(Batch):
 input_ids: torch.Tensor
 position_ids: torch.Tensor
 
-# tensor of length b holding starting offset of each sequence, only used in prefill
-start_seq_prefill: Optional[torch.Tensor]
-# tensor of length b holding ending offset of each sequence, only used in prefill
-end_seq_prefill: Optional[torch.Tensor]
+# Flash Attention values
+
+# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
+cu_seqlen_prefill: Optional[torch.Tensor]
 
 # Paged Attention values
 
@@ -197,8 +197,7 @@ class FlashCausalLMBatch(Batch):
 )["input_ids"]
 
 position_ids = []
-start_seq_prefill = []
-end_seq_prefill = []
+cu_seqlen_prefill = [0]
 needed_blocks_slots = []
 start_slots = []
 slot_indices = []
@@ -250,8 +249,7 @@ class FlashCausalLMBatch(Batch):
 position_ids.append(request_position_ids)
 
 # Add cumulative lengths of all previous inputs
-start_seq_prefill.append(cumulative_length)
-end_seq_prefill.append(cumulative_length + input_length)
+cu_seqlen_prefill.append(cumulative_length + input_length)
 
 next_token_chooser_parameters.append(r.parameters)
 
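Note: the list built in the loop above starts at [0] and appends one running total per request. A simplified, illustrative reconstruction with made-up lengths (abridged, not the actual batching code):

input_lengths = [3, 5, 2]

cu_seqlen_prefill = [0]
cumulative_length = 0
for input_length in input_lengths:
    cu_seqlen_prefill.append(cumulative_length + input_length)
    cumulative_length += input_length

print(cu_seqlen_prefill)  # [0, 3, 8, 10] -- one running total instead of parallel start/end lists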
@@ -329,11 +327,8 @@ class FlashCausalLMBatch(Batch):
 position_ids = position_ids[0]
 slot_indices = slot_indices[0]
 
-start_seq_prefill = torch.tensor(
-start_seq_prefill, device=device, dtype=torch.int32
-)
-end_seq_prefill = torch.tensor(
-end_seq_prefill, device=device, dtype=torch.int32
+cu_seqlen_prefill = torch.tensor(
+cu_seqlen_prefill, device=device, dtype=torch.int32
 )
 
 position_ids = position_ids.to(device)
@@ -345,9 +340,9 @@ class FlashCausalLMBatch(Batch):
 
 if all_prefill_logprobs:
 prefill_head_indices = None
-prefill_next_token_indices = end_seq_prefill - 1
+prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
 elif no_prefill_logprobs:
-prefill_head_indices = end_seq_prefill - 1
+prefill_head_indices = cu_seqlen_prefill[1:] - 1
 prefill_next_token_indices = None
 else:
 prefill_head_indices = torch.tensor(
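Note: because cu_seqlen_prefill carries a leading zero, cu_seqlen_prefill[1:] - 1 gives the flat index of the last prompt token of every sequence, which is exactly what end_seq_prefill - 1 used to provide. Illustrative values only:

import torch

# With prompt lengths [3, 5, 2], the flattened prefill tokens occupy
# positions [0, 10) and cu_seqlen_prefill is [0, 3, 8, 10].
cu_seqlen_prefill = torch.tensor([0, 3, 8, 10])

last_token_indices = cu_seqlen_prefill[1:] - 1
print(last_token_indices)  # tensor([2, 7, 9]) -- same as the old end_seq_prefill - 1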
@@ -363,8 +358,7 @@ class FlashCausalLMBatch(Batch):
 requests_idx_mapping=requests_idx_mapping,
 input_ids=input_ids,
 position_ids=position_ids,
-start_seq_prefill=start_seq_prefill,
-end_seq_prefill=end_seq_prefill,
+cu_seqlen_prefill=cu_seqlen_prefill,
 start_slots=start_slots,
 slot_indices=slot_indices,
 needed_blocks_slots=needed_blocks_slots,
@@ -504,8 +498,7 @@ class FlashCausalLMBatch(Batch):
 requests_idx_mapping=requests_idx_mapping,
 input_ids=input_ids,
 position_ids=position_ids,
-start_seq_prefill=None,
-end_seq_prefill=None,
+cu_seqlen_prefill=None,
 start_slots=start_slots,
 slot_indices=slot_indices,
 needed_blocks_slots=None,
@@ -652,8 +645,7 @@ class FlashCausalLMBatch(Batch):
 requests_idx_mapping=requests_idx_mapping,
 input_ids=input_ids,
 position_ids=position_ids,
-start_seq_prefill=None,
-end_seq_prefill=None,
+cu_seqlen_prefill=None,
 start_slots=start_slots,
 slot_indices=slot_indices,
 needed_blocks_slots=None,
@@ -750,8 +742,7 @@ class FlashCausalLM(Model):
 self,
 input_ids: torch.Tensor,
 position_ids: torch.Tensor,
-start_seq_prefill: Optional[torch.Tensor],
-end_seq_prefill: Optional[torch.Tensor],
+cu_seqlen_prefill: Optional[torch.Tensor],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
 input_lengths: torch.Tensor,
@@ -764,8 +755,7 @@ class FlashCausalLM(Model):
 return self.model.forward(
 input_ids=input_ids,
 position_ids=position_ids,
-start_seq_prefill=start_seq_prefill,
-end_seq_prefill=end_seq_prefill,
+cu_seqlen_prefill=cu_seqlen_prefill,
 kv_cache=CACHE_MANAGER.kv_cache,
 block_tables=block_tables,
 slots=slots,
@@ -778,7 +768,7 @@ class FlashCausalLM(Model):
 def generate_token(
 self, batch: FlashCausalLMBatch
 ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-prefill = batch.start_seq_prefill is not None
+prefill = batch.cu_seqlen_prefill is not None
 prefill_logprobs = batch.prefill_next_token_indices is not None
 
 if batch.needed_blocks_slots:
@@ -788,8 +778,7 @@ class FlashCausalLM(Model):
 out = self.forward(
 batch.input_ids,
 batch.position_ids,
-batch.start_seq_prefill,
-batch.end_seq_prefill,
+batch.cu_seqlen_prefill,
 batch.block_tables_tensor,
 batch.slots[batch.slot_indices],
 batch.input_lengths_tensor,
@@ -815,10 +804,9 @@ class FlashCausalLM(Model):
 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
 
 next_position_ids = batch.position_ids.new_empty(len(batch))
-batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1]
-# We do not need start_seq_prefill and end_seq_prefill anymore
-batch.start_seq_prefill = None
-batch.end_seq_prefill = None
+batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
+# We do not need cu_seqlen_prefill anymore
+batch.cu_seqlen_prefill = None
 else:
 prefill_logprobs = None
 next_position_ids = batch.position_ids
@@ -66,7 +66,9 @@ class MPTSharded(CausalLM):
 if local_path.exists():
 filename = str(local_path.resolve())
 else:
-filename = hf_hub_download(model_id, revision=revision, filename="config.json")
+filename = hf_hub_download(
+model_id, revision=revision, filename="config.json"
+)
 with open(filename, "r") as f:
 config = json.load(f)
 config = PretrainedConfig(**config)
@@ -359,7 +359,7 @@ try:
 def __init__(self, inv_freq):
 super().__init__()
 
-self.register_buffer("inv_freq", inv_freq)
+self.inv_freq = inv_freq
 self._seq_len_cached = 0
 self._cos_cached = None
 self._sin_cached = None
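Note on the last hunk: it swaps a registered buffer for a plain tensor attribute. The practical difference, shown with a generic PyTorch sketch (not code from this repository), is that a buffer ends up in the module's state_dict and follows .to()/.half() casts, while a plain attribute is ignored by both:

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self, inv_freq):
        super().__init__()
        self.register_buffer("inv_freq", inv_freq)  # tracked by the module

class PlainAttribute(nn.Module):
    def __init__(self, inv_freq):
        super().__init__()
        self.inv_freq = inv_freq  # just a Python attribute

inv_freq = torch.ones(4)
print("inv_freq" in WithBuffer(inv_freq).state_dict())      # True
print("inv_freq" in PlainAttribute(inv_freq).state_dict())  # False

print(WithBuffer(inv_freq).half().inv_freq.dtype)      # torch.float16
print(PlainAttribute(inv_freq).half().inv_freq.dtype)  # torch.float32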