OlivierDehaene 2023-07-20 17:23:49 +02:00 committed by GitHub
parent 08b8eec1d7
commit bf94df3c71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 13 deletions

View File

@ -154,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize, quantize=config.quantize,
dim=0 dim=0,
) )
if config.quantize != "gptq": if config.quantize != "gptq":
@ -168,7 +168,9 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size, config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize)) return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
class FlashLlamaAttention(torch.nn.Module): class FlashLlamaAttention(torch.nn.Module):

View File

@ -39,6 +39,7 @@ class CacheManager:
device: torch.device, device: torch.device,
): ):
self.block_size = BLOCK_SIZE self.block_size = BLOCK_SIZE
self.num_blocks = num_blocks
element_size = torch.tensor([], dtype=dtype).element_size() element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size x = self.block_size // element_size
@ -714,7 +715,6 @@ class FlashCausalLM(Model):
global CACHE_MANAGER global CACHE_MANAGER
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(self.device)
try: try:
CACHE_MANAGER = CacheManager( CACHE_MANAGER = CacheManager(
batch.blocks, batch.blocks,
@ -731,23 +731,20 @@ class FlashCausalLM(Model):
f"You need to decrease `--max-batch-prefill-tokens`" f"You need to decrease `--max-batch-prefill-tokens`"
) from e ) from e
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize(self.device) torch.cuda.synchronize(self.device)
peak_memory = torch.cuda.max_memory_reserved(self.device)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.dtype).element_size() dtype_size = torch.tensor([], dtype=self.dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory free_memory, _ = torch.cuda.mem_get_info(self.device)
# 0.98 to add some wiggle room
num_blocks = ( num_blocks = (
int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size) int(free_memory // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory. # Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ batch.blocks + CACHE_MANAGER.num_blocks
) )
del CACHE_MANAGER del CACHE_MANAGER

View File

@ -867,8 +867,9 @@ def quantize(
) )
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16, model = AutoModelForCausalLM.from_config(
trust_remote_code=trust_remote_code) config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
)
model = model.eval() model = model.eval()
print("LOADED model") print("LOADED model")