OlivierDehaene 2023-12-11 16:46:44 +01:00
parent d0841cc8eb
commit ec6d4592d5
7 changed files with 17 additions and 11 deletions

Cargo.lock (generated)

@@ -2754,7 +2754,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "1.3.0"
+version = "1.3.1"
 dependencies = [
  "average",
  "clap",
@@ -2775,7 +2775,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "1.3.0"
+version = "1.3.1"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2791,7 +2791,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "1.3.0"
+version = "1.3.1"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2807,7 +2807,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "1.3.0"
+version = "1.3.1"
 dependencies = [
  "async-stream",
  "axum",

Cargo.toml

@@ -8,7 +8,7 @@ members = [
 ]
 
 [workspace.package]
-version = "1.3.0"
+version = "1.3.1"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

docs/openapi.json

@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "1.3.0"
+    "version": "1.3.1"
   },
   "paths": {
     "/": {

integration-tests/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "1.3.0"
+version = "1.3.1"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]

server/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "1.3.0"
+version = "1.3.1"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

@@ -391,6 +391,7 @@ class MistralModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -398,7 +399,7 @@ class MistralModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
+            position_ids, true_max_s, hidden_states.dtype
         )
 
         residual = None
@@ -449,6 +450,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        true_max_s = max_s
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
@@ -467,6 +469,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
+            true_max_s,
             prefill_cache_indices,
         )
         if lm_head_indices is not None:
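
Why the new true_max_s argument: with sliding-window attention, a long prefill keeps only the last window tokens in the kv cache and max_s gets clamped accordingly, but position_ids still hold the original absolute positions. If the rotary cos/sin tables are sized by the clamped max_s, gathering rows for those positions indexes out of range. A minimal, self-contained sketch of the failure mode, using a toy get_cos_sin (assumed shapes and toy frequencies, not TGI's rotary cache):

    import torch

    def get_cos_sin(position_ids, max_s, dim=8):
        # Toy rotary cache: build cos/sin rows for positions [0, max_s),
        # then gather the rows that `position_ids` ask for.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        freqs = torch.outer(torch.arange(max_s).float(), inv_freq)
        return freqs.cos()[position_ids], freqs.sin()[position_ids]

    window = 4                            # toy sliding window
    position_ids = torch.arange(10)       # absolute positions 0..9
    kept = position_ids[-window:]         # positions still in the kv cache: 6..9

    try:
        get_cos_sin(kept, window)         # table sized by the clamped max_s
    except IndexError as err:
        print("clamped max_s fails:", err)  # index 6 out of bounds for size 4

    cos, sin = get_cos_sin(kept, 10)      # table sized by true_max_s: works
    print(cos.shape)                      # torch.Size([4, 4])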

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -401,7 +401,7 @@ class BlockSparseMoE(nn.Module):
             self.offsets_block_rows = block_rows
             offsets = self.offsets
         else:
-            offsets = self.offsets[:block_rows]
+            offsets = self.offsets[: block_rows + 1]
 
         # Indices for the sparse matrix. The indices for
         # the intermediate matrix are dynamic depending
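
For context on the [: block_rows + 1] change: the cached self.offsets buffer acts as a CSR-style row-pointer array, which needs one entry per block row plus one, because row r spans offsets[r] to offsets[r + 1]. Slicing to block_rows entries silently dropped the last row's end pointer. A small illustrative sketch with toy data (not the actual MoE buffers):

    import torch

    blocks_per_row = torch.tensor([2, 3, 1])  # nonzero blocks in each block row
    offsets = torch.cat([torch.zeros(1, dtype=torch.long), blocks_per_row.cumsum(0)])
    print(offsets)                            # tensor([0, 2, 5, 6])

    block_rows = blocks_per_row.numel()
    print(offsets[:block_rows])               # tensor([0, 2, 5]) -- last row's end lost
    print(offsets[: block_rows + 1])          # tensor([0, 2, 5, 6])

    r = block_rows - 1                        # reading row r needs offsets[r + 1]
    start, end = offsets[r].item(), offsets[r + 1].item()
    print(f"row {r} owns blocks [{start}, {end})")  # row 2 owns blocks [5, 6)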
@@ -632,6 +632,7 @@ class MixtralModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -639,7 +640,7 @@ class MixtralModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
+            position_ids, true_max_s, hidden_states.dtype
         )
 
         residual = None
@@ -690,6 +691,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        true_max_s = max_s
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
@@ -708,6 +710,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
+            true_max_s,
             prefill_cache_indices,
         )
         if lm_head_indices is not None:
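
Both forwards also depend on the slot slicing the context lines mention: slots assigns one kv-cache slot per input token, so when prefill_cache_indices drops tokens that fall outside the sliding window, slots must be sliced with the same indices. A toy sketch of that bookkeeping, with assumed values rather than the real cache layout:

    import torch

    window = 4
    seq_len = 10                             # prefill longer than the window

    slots = torch.arange(100, 100 + seq_len)  # one cache slot per input token
    prefill_cache_indices = torch.arange(seq_len - window, seq_len)  # keep last W

    kept_slots = slots[prefill_cache_indices]
    print(kept_slots)                        # tensor([106, 107, 108, 109])

    # The forward saves the original length before any clamping, mirroring
    # the diffs above: attention sees the window, rotary sees true_max_s.
    true_max_s = max_s = seq_len
    max_s = min(window, max_s)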