commit ec6d4592d5
parent d0841cc8eb
Author: OlivierDehaene
Date: 2023-12-11 16:46:44 +01:00

7 changed files with 17 additions and 11 deletions

Cargo.lock (generated)

@@ -2754,7 +2754,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "1.3.0"
+version = "1.3.1"
 dependencies = [
  "average",
  "clap",
@@ -2775,7 +2775,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "1.3.0"
+version = "1.3.1"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2791,7 +2791,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "1.3.0"
+version = "1.3.1"
 dependencies = [
  "clap",
  "ctrlc",
@@ -2807,7 +2807,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "1.3.0"
+version = "1.3.1"
 dependencies = [
  "async-stream",
  "axum",

Cargo.toml

@@ -8,7 +8,7 @@ members = [
 ]
 
 [workspace.package]
-version = "1.3.0"
+version = "1.3.1"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

docs/openapi.json

@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "1.3.0"
+    "version": "1.3.1"
   },
   "paths": {
     "/": {

integration-tests/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "1.3.0"
+version = "1.3.1"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]
 

server/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "1.3.0"
+version = "1.3.1"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 

server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

@@ -391,6 +391,7 @@ class MistralModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -398,7 +399,7 @@ class MistralModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
+            position_ids, true_max_s, hidden_states.dtype
         )
 
         residual = None
@@ -449,6 +450,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        true_max_s = max_s
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
@@ -467,6 +469,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
+            true_max_s,
             prefill_cache_indices,
         )
         if lm_head_indices is not None:
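
A note on the true_max_s change in this file (the Mixtral model below gets the identical fix): with sliding-window attention, max_s can be clamped to the window size before it reaches the paged-attention kernel, but the rotary cos/sin table must still cover every real position id in the batch. Carrying the unclamped length through as true_max_s and handing it to get_cos_sin keeps the table large enough. Below is a minimal sketch of the failure mode; the simplified get_cos_sin and the window size are illustrative, not the actual TGI implementation:

import torch

def get_cos_sin(position_ids: torch.Tensor, max_s: int, dim: int = 64):
    # Build a cos/sin table with one row per position in [0, max_s),
    # then gather the rows for the requested position ids.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(max_s).float(), inv_freq)
    return torch.cos(freqs)[position_ids], torch.sin(freqs)[position_ids]

window = 4096                       # hypothetical sliding-window size (max_past)
position_ids = torch.arange(6000)   # a sequence longer than the window

# Pre-fix behaviour: sizing the table with the clamped length leaves no rows
# for positions >= window, so the gather fails (or reads wrong values).
# get_cos_sin(position_ids, min(window, 6000))  # IndexError
cos, sin = get_cos_sin(position_ids, 6000)      # true_max_s covers all positions

The real rotary_emb.get_cos_sin caches and extends its table rather than rebuilding it, but the sizing argument plays the same role.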

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -401,7 +401,7 @@ class BlockSparseMoE(nn.Module):
             self.offsets_block_rows = block_rows
             offsets = self.offsets
         else:
-            offsets = self.offsets[:block_rows]
+            offsets = self.offsets[: block_rows + 1]
 
         # Indices for the sparse matrix. The indices for
         # the intermediate matrix are dynamic depending
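
A note on the one-character offsets fix above: offsets acts as a CSR-style row-offset tensor, where N block rows need N + 1 boundaries because row r owns the entries indices[offsets[r] : offsets[r + 1]], so slicing with [:block_rows] dropped the final boundary. A small illustration with made-up values:

import torch

# CSR-style row offsets: 4 rows need 5 boundaries, since row r spans
# the index range offsets[r] .. offsets[r + 1].
offsets = torch.tensor([0, 2, 5, 6, 9])
block_rows = 4

buggy = offsets[:block_rows]       # tensor([0, 2, 5, 6]); row 3's end is lost
fixed = offsets[: block_rows + 1]  # tensor([0, 2, 5, 6, 9]); row 3 spans 6..9

With the buggy slice the last block row would be read with a missing end boundary, which is why the cached tensor must be cut at block_rows + 1.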
@@ -632,6 +632,7 @@ class MixtralModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -639,7 +640,7 @@ class MixtralModel(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
+            position_ids, true_max_s, hidden_states.dtype
         )
 
         residual = None
@@ -690,6 +691,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        true_max_s = max_s
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
@@ -708,6 +710,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
+            true_max_s,
             prefill_cache_indices,
         )
         if lm_head_indices is not None: