Fixing rocm. (#2164)
This commit is contained in:
parent
b966bc0d35
commit
dea9c0dc74
|
@ -2,6 +2,7 @@ import os
|
|||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from loguru import logger
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
|
@ -45,8 +46,7 @@ def paged_attention(
|
|||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
cu_seqlen_q: torch.Tensor,
|
||||
cu_seqlen_k: torch.Tensor,
|
||||
input_lengths: Seqlen,
|
||||
max_s: int,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
|
@ -70,7 +70,7 @@ def paged_attention(
|
|||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
input_lengths = cu_seqlen_k
|
||||
input_lengths = input_lengths.input_lengths
|
||||
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
|
|
Loading…
Reference in New Issue