feat(server): use latest flash attention commit (#543)

@njhill FYI
OlivierDehaene, 2023-07-04 20:23:55 +02:00 (committed by GitHub)
parent e4b26aa10b
commit 31e2253ae7
11 changed files with 1067 additions and 1064 deletions
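Across the model files below, the start_seq_prefill / end_seq_prefill pair is replaced by a single cu_seqlen_prefill tensor of cumulative sequence lengths, which the updated flash_attn_cuda.fwd calls receive in place of the former start/end arguments. A minimal sketch of the equivalence (the sequence lengths here are made-up for illustration, not from the diff):

import torch

# Assumed example: a prefill batch of b = 3 packed sequences with lengths 3, 5 and 2.
input_lengths = [3, 5, 2]

# Old representation: per-sequence start and end offsets (two tensors of length b).
start_seq_prefill = torch.tensor([0, 3, 8], dtype=torch.int32)
end_seq_prefill = torch.tensor([3, 8, 10], dtype=torch.int32)

# New representation: one tensor of b + 1 cumulative lengths, starting at 0.
cu_seqlen_prefill = torch.tensor([0, 3, 8, 10], dtype=torch.int32)

# Both starts and ends are recoverable from the cumulative form,
# which is why a single tensor can replace the former pair.
assert torch.equal(cu_seqlen_prefill[:-1], start_seq_prefill)
assert torch.equal(cu_seqlen_prefill[1:], end_seq_prefill)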

View File

@ -1,9 +1,9 @@
flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f
flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec
flash-attention:
# Clone flash attention
pip install packaging
git clone https://github.com/OlivierDehaene/flash-attention.git
git clone https://github.com/HazyResearch/flash-attention.git
build-flash-attention: flash-attention
cd flash-attention && git fetch && git checkout $(flash_att_commit)

server/poetry.lock (generated): 1909 changed lines

File diff suppressed because it is too large.

View File

@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "0.13.3"
huggingface-hub = "^0.14.1"
transformers = "^4.29.2"
transformers = "4.29.2"
einops = "^0.6.1"
[tool.poetry.extras]

View File

@ -1,22 +1,23 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0"
numpy==1.25.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@ -27,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0"
protobuf==4.23.3 ; python_version >= "3.9" and python_version < "4.0"
pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"

View File

@ -135,8 +135,7 @@ class FlashLlamaAttention(torch.nn.Module):
hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -158,17 +157,15 @@ class FlashLlamaAttention(torch.nn.Module):
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
if cu_seqlen_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
@ -261,8 +258,7 @@ class FlashLlamaLayer(nn.Module):
residual,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -276,8 +272,7 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -329,8 +324,7 @@ class FlashLlamaModel(torch.nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
@ -352,8 +346,7 @@ class FlashLlamaModel(torch.nn.Module):
residual,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
@ -381,8 +374,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
@ -393,8 +385,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
hidden_states = self.model(
input_ids,
position_ids,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,

View File

@ -123,8 +123,7 @@ class FlashNeoxAttention(torch.nn.Module):
hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -146,17 +145,15 @@ class FlashNeoxAttention(torch.nn.Module):
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
if cu_seqlen_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
@ -246,8 +243,7 @@ class FlashNeoXLayer(nn.Module):
residual,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -261,8 +257,7 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -286,8 +281,7 @@ class FlashNeoXLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -341,8 +335,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
@ -364,8 +357,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
@ -391,8 +383,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
@ -403,8 +394,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
hidden_states = self.gpt_neox(
input_ids,
position_ids,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,

View File

@ -144,8 +144,7 @@ class FlashRWAttention(torch.nn.Module):
hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -176,7 +175,7 @@ class FlashRWAttention(torch.nn.Module):
attn_output = torch.empty_like(query)
# Prefill
if start_seq_prefill is not None:
if cu_seqlen_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@ -187,10 +186,8 @@ class FlashRWAttention(torch.nn.Module):
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
@ -276,8 +273,7 @@ class FlashRWLargeAttention(torch.nn.Module):
hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -311,7 +307,7 @@ class FlashRWLargeAttention(torch.nn.Module):
attn_output = torch.empty_like(query)
# Prefill
if start_seq_prefill is not None:
if cu_seqlen_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)
@ -325,10 +321,8 @@ class FlashRWLargeAttention(torch.nn.Module):
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
@ -428,8 +422,7 @@ class FlashRWLayer(nn.Module):
residual,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -443,8 +436,7 @@ class FlashRWLayer(nn.Module):
ln_hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -466,8 +458,7 @@ class FlashRWLayer(nn.Module):
hidden_states,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -516,8 +507,7 @@ class FlashRWLargeLayer(nn.Module):
residual,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -532,8 +522,7 @@ class FlashRWLargeLayer(nn.Module):
ln_attn,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -597,8 +586,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
@ -620,8 +608,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual,
cos,
sin,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
@ -648,8 +635,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
@ -660,8 +646,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,

View File

@ -232,8 +232,7 @@ class FlashMQAttention(torch.nn.Module):
def forward(
self,
hidden_states,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -259,7 +258,7 @@ class FlashMQAttention(torch.nn.Module):
attn_output = torch.empty_like(query)
# Prefill
if start_seq_prefill is not None:
if cu_seqlen_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
@ -269,10 +268,8 @@ class FlashMQAttention(torch.nn.Module):
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
cu_seqlen_prefill,
max_s,
max_s,
0.0,
@ -357,8 +354,7 @@ class Block(nn.Module):
self,
hidden_states,
residual,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -369,8 +365,7 @@ class Block(nn.Module):
hidden_states = self.attn(
hidden_states,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
@ -423,8 +418,7 @@ class FlashSantacoderModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
@ -441,8 +435,7 @@ class FlashSantacoderModel(nn.Module):
hidden_states, residual = layer(
hidden_states,
residual,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
@ -467,8 +460,7 @@ class FlashSantacoderForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
@ -479,8 +471,7 @@ class FlashSantacoderForCausalLM(nn.Module):
hidden_states = self.transformer(
input_ids,
position_ids,
start_seq_prefill,
end_seq_prefill,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,

View File

@ -121,10 +121,10 @@ class FlashCausalLMBatch(Batch):
input_ids: torch.Tensor
position_ids: torch.Tensor
# tensor of length b holding starting offset of each sequence, only used in prefill
start_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding ending offset of each sequence, only used in prefill
end_seq_prefill: Optional[torch.Tensor]
# Flash Attention values
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Paged Attention values
@ -197,8 +197,7 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
position_ids = []
start_seq_prefill = []
end_seq_prefill = []
cu_seqlen_prefill = [0]
needed_blocks_slots = []
start_slots = []
slot_indices = []
@ -250,8 +249,7 @@ class FlashCausalLMBatch(Batch):
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
start_seq_prefill.append(cumulative_length)
end_seq_prefill.append(cumulative_length + input_length)
cu_seqlen_prefill.append(cumulative_length + input_length)
next_token_chooser_parameters.append(r.parameters)
@ -329,11 +327,8 @@ class FlashCausalLMBatch(Batch):
position_ids = position_ids[0]
slot_indices = slot_indices[0]
start_seq_prefill = torch.tensor(
start_seq_prefill, device=device, dtype=torch.int32
)
end_seq_prefill = torch.tensor(
end_seq_prefill, device=device, dtype=torch.int32
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
@ -345,9 +340,9 @@ class FlashCausalLMBatch(Batch):
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = end_seq_prefill - 1
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = end_seq_prefill - 1
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
@ -363,8 +358,7 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill,
cu_seqlen_prefill=cu_seqlen_prefill,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
@ -504,8 +498,7 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
start_seq_prefill=None,
end_seq_prefill=None,
cu_seqlen_prefill=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
@ -652,8 +645,7 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
start_seq_prefill=None,
end_seq_prefill=None,
cu_seqlen_prefill=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
@ -750,8 +742,7 @@ class FlashCausalLM(Model):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
@ -764,8 +755,7 @@ class FlashCausalLM(Model):
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=CACHE_MANAGER.kv_cache,
block_tables=block_tables,
slots=slots,
@ -778,7 +768,7 @@ class FlashCausalLM(Model):
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.start_seq_prefill is not None
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
if batch.needed_blocks_slots:
@ -788,8 +778,7 @@ class FlashCausalLM(Model):
out = self.forward(
batch.input_ids,
batch.position_ids,
batch.start_seq_prefill,
batch.end_seq_prefill,
batch.cu_seqlen_prefill,
batch.block_tables_tensor,
batch.slots[batch.slot_indices],
batch.input_lengths_tensor,
@ -815,10 +804,9 @@ class FlashCausalLM(Model):
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1]
# We do not need start_seq_prefill and end_seq_prefill anymore
batch.start_seq_prefill = None
batch.end_seq_prefill = None
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
# We do not need cu_seqlen_prefill anymore
batch.cu_seqlen_prefill = None
else:
prefill_logprobs = None
next_position_ids = batch.position_ids
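In the batch code above, the index math also moves to the cumulative form: prefill_head_indices / prefill_next_token_indices and the slot_indices re-indexing all derive the last-token position of each sequence as cu_seqlen_prefill[1:] - 1. A small sketch with assumed lengths:

import torch

# Assumed example: prefill lengths 3, 5, 2 -> a flat batch of 10 tokens.
cu_seqlen_prefill = torch.tensor([0, 3, 8, 10], dtype=torch.int32)

# Index of the last token of every sequence in the flat batch
# (previously computed as end_seq_prefill - 1).
last_token_indices = cu_seqlen_prefill[1:] - 1   # tensor([2, 7, 9])

# These indices pick the logits for the next-token step and re-index
# slot_indices after prefill, at which point the batch drops the tensor
# entirely (batch.cu_seqlen_prefill = None).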

View File

@ -66,7 +66,9 @@ class MPTSharded(CausalLM):
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
filename = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)
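The hunk above only reflows the hf_hub_download call; for reference, a minimal self-contained sketch of the pattern it formats (the model id here is a placeholder assumption, the real code receives model_id and revision as arguments):

import json

from huggingface_hub import hf_hub_download
from transformers import PretrainedConfig

# Placeholder repo id standing in for model_id.
filename = hf_hub_download(
    "mosaicml/mpt-7b", revision=None, filename="config.json"
)
with open(filename, "r") as f:
    config = PretrainedConfig(**json.load(f))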

View File

@ -359,7 +359,7 @@ try:
def __init__(self, inv_freq):
super().__init__()
self.register_buffer("inv_freq", inv_freq)
self.inv_freq = inv_freq
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
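The last hunk swaps register_buffer for a plain attribute assignment on the rotary embedding's inv_freq. The practical difference, as a small standalone sketch (class names here are illustrative, not from the diff):

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self, inv_freq):
        super().__init__()
        # A registered buffer shows up in state_dict and follows .to(device) / .cuda().
        self.register_buffer("inv_freq", inv_freq)

class PlainAttribute(nn.Module):
    def __init__(self, inv_freq):
        super().__init__()
        # A plain attribute is invisible to state_dict, so checkpoints neither
        # save nor expect an inv_freq entry; device placement is handled manually.
        self.inv_freq = inv_freq

inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
print(list(WithBuffer(inv_freq).state_dict().keys()))      # ['inv_freq']
print(list(PlainAttribute(inv_freq).state_dict().keys()))  # []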