fix(server): use mem_get_info to get kv cache size (#664)
Close https://github.com/huggingface/text-generation-inference/issues/649
Close https://github.com/huggingface/text-generation-inference/issues/651
Close https://github.com/huggingface/text-generation-inference/issues/653
Close #636
This commit is contained in:
parent 08b8eec1d7
commit bf94df3c71
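For orientation, a minimal standalone sketch of the sizing strategy this commit switches to (not the repository's code; num_layers, num_kv_heads, head_size and warmup_blocks are placeholder parameters): after the warmup cache has been allocated, the free device memory reported by torch.cuda.mem_get_info is divided by the byte size of one KV-cache block, and the warmup blocks are added back because they are freed before the real cache is created.

import torch

BLOCK_SIZE = 16  # tokens per KV-cache block

def estimate_kv_cache_blocks(device, dtype, num_layers, num_kv_heads, head_size, warmup_blocks):
    # Bytes for one block: BLOCK_SIZE tokens * kv heads * head size,
    # stored twice per layer (keys and values).
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size

    # Free memory as seen by the driver *after* the warmup allocation,
    # so the warmup blocks are counted back in.
    free_memory, _ = torch.cuda.mem_get_info(device)
    return int(free_memory // total_cache_size) + warmup_blocks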
@@ -154,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
     weight = weights.get_multi_weights_col(
         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
         quantize=config.quantize,
-        dim=0
+        dim=0,
     )
 
     if config.quantize != "gptq":
@@ -168,7 +168,9 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(
+        get_linear(weight, bias=None, quantize=config.quantize)
+    )
 
 
 class FlashLlamaAttention(torch.nn.Module):
@@ -39,6 +39,7 @@ class CacheManager:
         device: torch.device,
     ):
         self.block_size = BLOCK_SIZE
+        self.num_blocks = num_blocks
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size
@@ -714,7 +715,6 @@ class FlashCausalLM(Model):
         global CACHE_MANAGER
 
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats(self.device)
         try:
             CACHE_MANAGER = CacheManager(
                 batch.blocks,
@@ -731,23 +731,20 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
-        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
-        # Calculate the number of blocks that can be allocated with the
-        # profiled peak memory.
         torch.cuda.synchronize(self.device)
-        peak_memory = torch.cuda.max_memory_reserved(self.device)
 
+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        free_memory, _ = torch.cuda.mem_get_info(self.device)
 
-        # 0.98 to add some wiggle room
         num_blocks = (
-            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
+            int(free_memory // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + batch.blocks
+            + CACHE_MANAGER.num_blocks
         )
 
         del CACHE_MANAGER
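As a rough sanity check of the arithmetic in the hunk above, with hypothetical model dimensions (not taken from this repository): 32 layers, 32 KV heads, head size 128, float16, and 20 GiB reported free by mem_get_info.

BLOCK_SIZE = 16
dtype_size = 2                                              # bytes per float16 element
cache_block_size = BLOCK_SIZE * 32 * 128                    # 65_536 elements per layer per block
total_cache_size = 32 * cache_block_size * 2 * dtype_size   # 8_388_608 bytes (8 MiB) per block

free_memory = 20 * 1024**3                                  # suppose 20 GiB is reported free
print(free_memory // total_cache_size)                      # 2560 blocks, i.e. 2560 * 16 = 40_960 cacheable tokens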
@@ -867,8 +867,9 @@ def quantize(
     )
 
     with init_empty_weights():
-        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16,
-                                                 trust_remote_code=trust_remote_code)
+        model = AutoModelForCausalLM.from_config(
+            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
+        )
     model = model.eval()
 
     print("LOADED model")