fix(server): use mem_get_info to get kv cache size (#664)
Close https://github.com/huggingface/text-generation-inference/issues/649
Close https://github.com/huggingface/text-generation-inference/issues/651
Close https://github.com/huggingface/text-generation-inference/issues/653
Close #636
This commit is contained in:
parent 08b8eec1d7
commit bf94df3c71
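For context: torch.cuda.mem_get_info is a thin wrapper around the driver's cudaMemGetInfo and returns the free and total memory of a CUDA device in bytes, so allocations made by other processes on the same GPU are already excluded from the free figure. A minimal sketch, assuming device index 0:

    import torch

    # (free, total) in bytes as reported by the CUDA driver for the given device.
    free_bytes, total_bytes = torch.cuda.mem_get_info(torch.device("cuda:0"))
    print(f"{free_bytes / 2**30:.2f} GiB free of {total_bytes / 2**30:.2f} GiB")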
@@ -154,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
     weight = weights.get_multi_weights_col(
         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
         quantize=config.quantize,
-        dim=0
+        dim=0,
     )
 
     if config.quantize != "gptq":
@@ -168,7 +168,9 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(
+        get_linear(weight, bias=None, quantize=config.quantize)
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):
@@ -39,6 +39,7 @@ class CacheManager:
         device: torch.device,
     ):
         self.block_size = BLOCK_SIZE
+        self.num_blocks = num_blocks
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size
@@ -714,7 +715,6 @@ class FlashCausalLM(Model):
         global CACHE_MANAGER
 
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats(self.device)
         try:
             CACHE_MANAGER = CacheManager(
                 batch.blocks,
@@ -731,23 +731,20 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
-        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
-        # Calculate the number of blocks that can be allocated with the
-        # profiled peak memory.
         torch.cuda.synchronize(self.device)
-        peak_memory = torch.cuda.max_memory_reserved(self.device)
 
+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        free_memory, _ = torch.cuda.mem_get_info(self.device)
 
-        # 0.98 to add some wiggle room
         num_blocks = (
-            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
+            int(free_memory // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + batch.blocks
+            + CACHE_MANAGER.num_blocks
         )
 
         del CACHE_MANAGER
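To make the new sizing concrete, here is a hedged, standalone sketch of the same arithmetic; the model dimensions below are assumptions for illustration (roughly a 7B Llama), not values taken from the commit:

    import torch

    # Assumed model dimensions, for illustration only.
    BLOCK_SIZE = 16          # tokens per KV-cache block
    num_layers = 32
    num_kv_heads = 32
    head_size = 128
    dtype = torch.float16

    dtype_size = torch.tensor([], dtype=dtype).element_size()           # 2 bytes for float16
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size            # elements per block, per layer, per K or V
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size   # bytes per block (keys + values, all layers)

    # New approach: size the cache from what the driver reports as free,
    # instead of total_memory * 0.98 - max_memory_reserved(), which only sees
    # this process's PyTorch allocator and can over-estimate on shared GPUs.
    free_memory, _ = torch.cuda.mem_get_info(torch.device("cuda:0"))    # device index assumed
    num_blocks = int(free_memory // total_cache_size)

In the diff, CACHE_MANAGER.num_blocks is added on top of this count because the warmup cache manager is still holding those blocks when mem_get_info is called, so their memory does not show up as free.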
@@ -867,8 +867,9 @@ def quantize(
     )
 
     with init_empty_weights():
-        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16,
-                                                 trust_remote_code=trust_remote_code)
+        model = AutoModelForCausalLM.from_config(
+            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
+        )
     model = model.eval()
 
     print("LOADED model")