Fixing rocm. (#2164)

Nicolas Patry 2024-07-02 12:01:08 +02:00 committed by GitHub
parent b966bc0d35
commit dea9c0dc74
1 changed file with 3 additions and 3 deletions


@@ -2,6 +2,7 @@ import os
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.layers.attention import Seqlen
 from loguru import logger

 major, minor = torch.cuda.get_device_capability()
@@ -45,8 +46,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: Seqlen,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -70,7 +70,7 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k
+    input_lengths = input_lengths.input_lengths
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
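
For context on the diff above: the ROCm `paged_attention` signature drops the raw `cu_seqlen_q`/`cu_seqlen_k` tensors in favor of a `Seqlen` object and unwraps the lengths tensor inside the function. The sketch below is a hypothetical illustration of that pattern; the `Seqlen` dataclass and `paged_attention_rocm_stub` function are assumptions made for this example, not the real definitions from `text_generation_server`.

```python
# Hypothetical sketch, not the real text_generation_server definitions.
# It only illustrates the shape of the change: callers pass a Seqlen-style
# wrapper, and the ROCm paged-attention path unwraps the raw lengths tensor.
from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # One entry per sequence in the batch (assumed field layout).
    input_lengths: torch.Tensor


def paged_attention_rocm_stub(input_lengths: Seqlen, max_s: int) -> torch.Tensor:
    # Mirrors the patched line above: the kernel needs the plain tensor,
    # so the wrapper is unwrapped at the top of the function.
    lengths = input_lengths.input_lengths
    assert int(lengths.max()) <= max_s
    return lengths


if __name__ == "__main__":
    seqlen = Seqlen(input_lengths=torch.tensor([7, 12, 3], dtype=torch.int32))
    print(paged_attention_rocm_stub(seqlen, max_s=12))
```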