Initialize v_cache to avoid NaNs (#11)
parent 1e646fb41d
commit 2fda8fe812
@@ -44,6 +44,8 @@ class CacheManager:
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = self.block_size // element_size
 
+        # v_cache can be initialized with any finite value, because the initial value does not take effect unless it is infinite or NaN. We use the max positive value to expose any potential bugs if the value is accidentally used in any computation.
+        self.v_cache_initial_value = torch.finfo(dtype).max
         self.kv_cache = [
             (
                 torch.empty(
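The sentinel choice works because NaN and inf are the only values that survive multiplication by a zero attention weight. A minimal sketch (not part of the commit; the dtype and the zero weight are illustrative assumptions) of why a finite sentinel is harmless while NaN/inf would poison the output:

```python
import torch

dtype = torch.float16
sentinel = torch.finfo(dtype).max  # 65504.0 for float16

# A masked-out cache slot contributes with an attention weight of exactly 0.
weight = torch.tensor([0.0])

print(weight * sentinel)                      # tensor([0.])  -> finite garbage vanishes
print(weight * torch.tensor([float("nan")]))  # tensor([nan]) -> NaN survives the mask
print(weight * torch.tensor([float("inf")]))  # tensor([nan]) -> 0 * inf is also NaN
```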
@@ -110,6 +112,10 @@ class CacheManager:
         # Reset mask
         self.free_block_mask[block_indices] = 1
 
+        # Initialize v_cache to avoid NaNs
+        # See discussion at https://github.com/vllm-project/vllm/issues/641#issuecomment-1682619534
+        for _k_cache, v_cache in self.kv_cache:
+            v_cache[block_indices] = self.v_cache_initial_value
 
 
 @dataclass
 class FlashCausalLMBatch(Batch):
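Taken together, the two hunks amount to the following self-contained sketch of the free path (hypothetical where the diff does not show details; `num_blocks`, `block_size`, `head_dim`, and the `free` helper are illustrative, not the repository's actual signatures): freed blocks are marked available and their value-cache slots refilled with the finite sentinel, so stale NaN/inf garbage from `torch.empty` can never leak into a later forward pass.

```python
import torch

num_blocks, block_size, head_dim = 4, 16, 8
dtype = torch.float16
v_cache_initial_value = torch.finfo(dtype).max

free_block_mask = torch.zeros(num_blocks, dtype=torch.int32)
kv_cache = [
    (
        torch.empty(num_blocks, block_size, head_dim, dtype=dtype),  # k_cache
        torch.empty(num_blocks, block_size, head_dim, dtype=dtype),  # v_cache
    )
]

def free(block_indices: torch.Tensor) -> None:
    # Reset mask: mark the blocks as available again
    free_block_mask[block_indices] = 1
    # Refill the freed v_cache slots with the finite sentinel (mirrors the diff above)
    for _k_cache, v_cache in kv_cache:
        v_cache[block_indices] = v_cache_initial_value

free(torch.tensor([0, 2]))
assert not torch.isnan(kv_cache[0][1][torch.tensor([0, 2])]).any()
```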