Initialize v_cache to avoid NaNs (#11)

Yang, Bo, 2023-08-23 14:07:06 -07:00 (committed by GitHub)
parent 1e646fb41d
commit 2fda8fe812
1 changed file with 6 additions and 0 deletions

@@ -44,6 +44,8 @@ class CacheManager:
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size
+        # v_cache can be initialized with any finite value, because the initial value does not take effect unless it is infinite or NaN. We use the maximum positive value to expose any potential bugs if the value is accidentally used in any computation.
+        self.v_cache_initial_value = torch.finfo(dtype).max
         self.kv_cache = [
             (
                 torch.empty(
@@ -110,6 +112,10 @@ class CacheManager:
             # Reset mask
             self.free_block_mask[block_indices] = 1
+            # Initialize v_cache to avoid NaNs
+            # See discussion at https://github.com/vllm-project/vllm/issues/641#issuecomment-1682619534
+            for _k_cache, v_cache in self.kv_cache:
+                v_cache[block_indices] = self.v_cache_initial_value
 @dataclass
 class FlashCausalLMBatch(Batch):
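
For context (not part of the commit): a NaN or Inf value left in an unused v_cache slot is dangerous, while any finite placeholder is harmless, because under IEEE-754 both 0 * NaN and 0 * inf are NaN, so a zero attention weight does not neutralize garbage in the slot. A minimal PyTorch sketch of that behavior, with illustrative variable names of my own choosing:

import torch

# An attention weight of 0 is meant to ignore the second cache slot.
weights = torch.tensor([1.0, 0.0])

# Finite placeholder (the commit uses torch.finfo(dtype).max): the weighted sum stays finite.
finite_slot = torch.tensor([2.0, torch.finfo(torch.float32).max])
print(weights @ finite_slot)  # tensor(2.)

# NaN placeholder (e.g. uninitialized memory from torch.empty): NaN propagates despite the 0 weight.
nan_slot = torch.tensor([2.0, float("nan")])
print(weights @ nan_slot)  # tensor(nan)

Choosing torch.finfo(dtype).max rather than, say, zero also means that if the placeholder is ever read with a nonzero weight, the output blows up visibly instead of silently looking plausible, which matches the rationale in the added comment.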