From afd04dc71e1ba83727975080460277d80f975f1e Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 28 Jul 2023 15:36:38 +0200 Subject: [PATCH] feat(server): update vllm version (#723) --- router/src/main.rs | 6 +++- server/Makefile-vllm | 2 +- server/text_generation_server/utils/layers.py | 35 ++++++++----------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/router/src/main.rs b/router/src/main.rs index 059f8692..484643cb 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -233,6 +233,10 @@ fn main() -> Result<(), RouterError> { "Inferred max batch total tokens: {max_supported_batch_total_tokens}" ); } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}"))); + } + max_supported_batch_total_tokens } }; @@ -270,7 +274,7 @@ fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, ) - .await?; + .await?; Ok(()) }) } diff --git a/server/Makefile-vllm b/server/Makefile-vllm index af750733..9100fff4 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,4 +1,4 @@ -vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9 +vllm_commit := 084ca75d4271f8f67be731bc58e0d41d8e0afd3a vllm: # Clone vllm diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 183cf2c1..7a45808e 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -219,36 +219,31 @@ class TensorParallelHead(SuperLayer): ) def forward(self, input: torch.Tensor) -> torch.Tensor: - if not self.should_gather: - return super().forward(input) - world_size = self.process_group.size() - if len(input.shape) == 2 and isinstance(self.linear, FastLinear): + # Fast branch for single requests + if ( + self.should_gather + and len(input.shape) == 2 + and isinstance(self.linear, FastLinear) + and input.shape[0] == 1 + ): out_dim = self.linear.weight.shape[0] - if input.shape[0] == 1: - world_out = input.new_empty(1, out_dim * world_size) - local_out = input.new_empty(1, out_dim) - gather_input = local_out - else: - world_out = input.new_empty(out_dim * world_size, input.shape[0]) - gather_input = input.new_empty(out_dim, input.shape[0]) - local_out = gather_input.T + world_out = input.new_empty(1, out_dim * world_size) + local_out = input.new_empty(1, out_dim) torch.mm(input, self.linear.weight.T, out=local_out) torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group + world_out, local_out, group=self.process_group ) - - if input.shape[0] == 1: - return world_out - return world_out.T + return world_out output = super().forward(input) - world_output = [ - torch.empty_like(output) for _ in range(self.process_group.size()) - ] + if not self.should_gather: + return output + + world_output = [torch.empty_like(output) for _ in range(world_size)] torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output