feat(server): update vllm version (#723)
This commit is contained in:
parent
f848decee6
commit
afd04dc71e
|
@ -233,6 +233,10 @@ fn main() -> Result<(), RouterError> {
|
||||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
max_supported_batch_total_tokens
|
max_supported_batch_total_tokens
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
|
vllm_commit := 084ca75d4271f8f67be731bc58e0d41d8e0afd3a
|
||||||
|
|
||||||
vllm:
|
vllm:
|
||||||
# Clone vllm
|
# Clone vllm
|
||||||
|
|
|
@ -219,36 +219,31 @@ class TensorParallelHead(SuperLayer):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
if not self.should_gather:
|
|
||||||
return super().forward(input)
|
|
||||||
|
|
||||||
world_size = self.process_group.size()
|
world_size = self.process_group.size()
|
||||||
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
|
# Fast branch for single requests
|
||||||
|
if (
|
||||||
|
self.should_gather
|
||||||
|
and len(input.shape) == 2
|
||||||
|
and isinstance(self.linear, FastLinear)
|
||||||
|
and input.shape[0] == 1
|
||||||
|
):
|
||||||
out_dim = self.linear.weight.shape[0]
|
out_dim = self.linear.weight.shape[0]
|
||||||
|
|
||||||
if input.shape[0] == 1:
|
|
||||||
world_out = input.new_empty(1, out_dim * world_size)
|
world_out = input.new_empty(1, out_dim * world_size)
|
||||||
local_out = input.new_empty(1, out_dim)
|
local_out = input.new_empty(1, out_dim)
|
||||||
gather_input = local_out
|
|
||||||
else:
|
|
||||||
world_out = input.new_empty(out_dim * world_size, input.shape[0])
|
|
||||||
gather_input = input.new_empty(out_dim, input.shape[0])
|
|
||||||
local_out = gather_input.T
|
|
||||||
|
|
||||||
torch.mm(input, self.linear.weight.T, out=local_out)
|
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||||
|
|
||||||
torch.distributed.all_gather_into_tensor(
|
torch.distributed.all_gather_into_tensor(
|
||||||
world_out, gather_input, group=self.process_group
|
world_out, local_out, group=self.process_group
|
||||||
)
|
)
|
||||||
|
|
||||||
if input.shape[0] == 1:
|
|
||||||
return world_out
|
return world_out
|
||||||
return world_out.T
|
|
||||||
|
|
||||||
output = super().forward(input)
|
output = super().forward(input)
|
||||||
world_output = [
|
if not self.should_gather:
|
||||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
return output
|
||||||
]
|
|
||||||
|
world_output = [torch.empty_like(output) for _ in range(world_size)]
|
||||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||||
world_output = torch.cat(world_output, dim=-1)
|
world_output = torch.cat(world_output, dim=-1)
|
||||||
return world_output
|
return world_output
|
||||||
|
|
Loading…
Reference in New Issue