feat(server): update vllm version (#723)

OlivierDehaene 2023-07-28 15:36:38 +02:00 committed by GitHub
parent f848decee6
commit afd04dc71e
3 changed files with 21 additions and 22 deletions

router/src/main.rs

@@ -233,6 +233,10 @@ fn main() -> Result<(), RouterError> {
                     "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                 );
             }
+            if max_total_tokens as u32 > max_supported_batch_total_tokens {
+                return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
+            }
+
            max_supported_batch_total_tokens
        }
    };
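
The new guard encodes a simple scheduling invariant: a single request may use up to `max_total_tokens` (prompt plus generated) tokens, and every one of those tokens counts against the batch-wide `max_batch_total_tokens` budget, so a per-request cap above the batch budget could never be scheduled. A minimal Python sketch of the same invariant (illustrative only; the real check is the Rust above, and `validate_batch_budget` is a hypothetical name, not part of the router's API):

    # Illustrative mirror of the Rust guard above; `validate_batch_budget`
    # is a hypothetical name, not anything exposed by the router.
    def validate_batch_budget(max_total_tokens: int, max_batch_total_tokens: int) -> int:
        # Every token of one request counts against the batch-wide budget,
        # so the per-request cap can never exceed it.
        if max_total_tokens > max_batch_total_tokens:
            raise ValueError(
                f"`max_total_tokens` must be <= `max_batch_total_tokens`. "
                f"Given: {max_total_tokens} and {max_batch_total_tokens}"
            )
        return max_batch_total_tokens

    validate_batch_budget(4096, 16000)    # OK: one request fits in the batch budget
    # validate_batch_budget(32000, 16000) # raises: such a request could never be scheduled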

server/Makefile-vllm

@@ -1,4 +1,4 @@
-vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
+vllm_commit := 084ca75d4271f8f67be731bc58e0d41d8e0afd3a
 
 vllm:
 	# Clone vllm

server/text_generation_server/utils/layers.py

@@ -219,36 +219,31 @@ class TensorParallelHead(SuperLayer):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if not self.should_gather:
-            return super().forward(input)
-
         world_size = self.process_group.size()
-        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+        # Fast branch for single requests
+        if (
+            self.should_gather
+            and len(input.shape) == 2
+            and isinstance(self.linear, FastLinear)
+            and input.shape[0] == 1
+        ):
             out_dim = self.linear.weight.shape[0]
 
-            if input.shape[0] == 1:
-                world_out = input.new_empty(1, out_dim * world_size)
-                local_out = input.new_empty(1, out_dim)
-                gather_input = local_out
-            else:
-                world_out = input.new_empty(out_dim * world_size, input.shape[0])
-                gather_input = input.new_empty(out_dim, input.shape[0])
-                local_out = gather_input.T
+            world_out = input.new_empty(1, out_dim * world_size)
+            local_out = input.new_empty(1, out_dim)
 
             torch.mm(input, self.linear.weight.T, out=local_out)
 
             torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
+                world_out, local_out, group=self.process_group
             )
-
-            if input.shape[0] == 1:
-                return world_out
-            return world_out.T
-
+            return world_out
+
         output = super().forward(input)
-        world_output = [
-            torch.empty_like(output) for _ in range(self.process_group.size())
-        ]
+        if not self.should_gather:
+            return output
+
+        world_output = [torch.empty_like(output) for _ in range(world_size)]
         torch.distributed.all_gather(world_output, output, group=self.process_group)
         world_output = torch.cat(world_output, dim=-1)
         return world_output
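
For context on what the rewritten fast branch computes: the head's weight is split along its output (vocab) dimension across ranks, each rank's torch.mm produces a (1, out_dim) slice of the logits, and all_gather_into_tensor lays those slices side by side in world_out. A minimal single-process sketch of that math (world_size and the shapes are made-up illustration values; this is not the repo's code):

    import torch

    # Simulate a column-sharded head on one process. `world_size`, `hidden`,
    # and `out_dim` are made-up illustration values.
    world_size, hidden, out_dim = 4, 16, 8
    weight = torch.randn(out_dim * world_size, hidden)  # full (unsharded) weight
    shards = weight.chunk(world_size, dim=0)            # the slice each rank would hold
    input = torch.randn(1, hidden)                      # single-request hidden state

    # Each rank runs torch.mm(input, shard.T) -> (1, out_dim); the
    # all_gather_into_tensor call then lays those results side by side in
    # `world_out`, which on one process is a concatenation along the last dim:
    world_out = torch.cat([input @ shard.T for shard in shards], dim=-1)

    # Correctness condition the gather relies on: the concatenated shard
    # outputs equal the output of the full, unsharded matmul.
    assert torch.allclose(world_out, input @ weight.T, atol=1e-5)

The batched path below the fast branch does the same concatenation after the fact: torch.distributed.all_gather fills a list of per-rank outputs and torch.cat(world_output, dim=-1) joins them.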