feat: add triton kernels to decrease latency of large batches (#2687)
* feat: add triton kernels to decrease latency of large batches
* cast to int32
* fix kernel
* fix kernel
* disable triton on rocm
* fix speculation
* add slots filtering kernel

parent 0f346a3296, commit 6f88bd9390
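The changes below follow a single pattern: batch metadata (padded and ragged block tables, position ids, slot indices, and the copy of accepted token ids) is built by fused Triton kernels when the device supports them, while the previous per-request Python loops are kept as a fallback. A minimal sketch of that gating, using the has_triton and block_tables_to_padded helpers added by this commit; the tensor values are illustrative and a CUDA device is assumed:

import torch

from text_generation_server.models.metadata_kernels import (
    has_triton,
    block_tables_to_padded,
)

# Illustrative ragged layout: two sequences owning 3 and 2 KV-cache blocks.
block_tables_ragged = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int32, device="cuda")
cu_blocks = torch.tensor([0, 3, 5], dtype=torch.int64, device="cuda")
max_blocks = 3
# Unused tail entries stay uninitialized, as with torch.empty in the real code.
block_tables = torch.empty((2, max_blocks), dtype=torch.int32, device="cuda")

if has_triton():
    # One fused launch fills the padded [batch, max_blocks] table.
    block_tables_to_padded(max_blocks, cu_blocks, block_tables, block_tables_ragged)
else:
    # Fallback: copy each request's blocks row by row, as the code did before.
    for i in range(2):
        start, end = cu_blocks[i], cu_blocks[i + 1]
        block_tables[i, : end - start] = block_tables_ragged[start:end]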
@@ -71,6 +71,14 @@ from text_generation_server.utils.import_utils import (
     synchronize,
     get_free_memory,
 )
+from text_generation_server.models.metadata_kernels import (
+    has_triton,
+    copy_next_input_ids_inplace,
+    block_tables_to_ragged,
+    block_tables_to_padded,
+    prepare_position_slot_ids,
+    slots_filtering,
+)

 tracer = trace.get_tracer(__name__)

@@ -147,8 +155,10 @@ class FlashCausalLMBatch(Batch):
     # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
-    slots: Optional[torch.Tensor]
+    slots: torch.Tensor
+    # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch
+    # used for filtering
+    cu_slots: torch.Tensor

     max_input_length: int
     max_current_length: int
@@ -159,7 +169,7 @@ class FlashCausalLMBatch(Batch):
     prefilling_mask: List[bool]

     # Prefill metadata tensors to efficiently compute logprobs
-    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
+    # tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
     # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
     # as we only keep SLIDING_WINDOW values instead of the whole tensor
@@ -257,6 +267,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         all_postfix_ids = []
         requests_idx_mapping = {}
+        slots = []
+        cu_slots = [0]

         next_token_chooser_parameters = []
         stopping_criterias = []
@@ -268,7 +280,9 @@ class FlashCausalLMBatch(Batch):
         max_length = 0
         max_blocks = 0

+        cu_blocks = [0]
         block_tables = []
+        block_tables_ragged = []

         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -341,10 +355,21 @@ class FlashCausalLMBatch(Batch):
                 request_blocks = [
                     b for b in range(num_blocks, num_blocks + needed_blocks)
                 ]
+                request_slots = [
+                    s
+                    for b in request_blocks
+                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                ]
             else:
                 request_blocks = r.blocks
+                request_slots = r.slots

             block_tables.append(request_blocks)
+            block_tables_ragged.extend(request_blocks)
+            cu_blocks.append(len(block_tables_ragged))
+
+            slots.extend(request_slots)
+            cu_slots.append(len(slots))

             cache_lengths.append(cache_length)
             num_blocks += len(request_blocks)
@@ -378,16 +403,34 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )

-        block_tables_tensor = torch.zeros(
-            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
+        block_tables_ragged = torch.tensor(
+            block_tables_ragged, device=device, dtype=torch.int32
         )
-        for i, request_blocks in enumerate(block_tables):
-            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
-        block_tables_tensor = block_tables_tensor.to(device)
+        cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
+        block_tables_tensor = torch.empty(
+            (len(block_tables), max_blocks),
+            device=device,
+            dtype=torch.int32,
+        )
+
+        # If the device supports Triton, we can use a fused kernel
+        if has_triton():
+            block_tables_to_padded(
+                max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged
+            )
+        else:
+            for i, request_blocks in enumerate(block_tables):
+                block_tables_tensor[i, : len(request_blocks)] = torch.tensor(
+                    request_blocks
+                )

         prompt_lengths_tensor = torch.tensor(
             prompt_lengths, dtype=torch.int32, device=device
         )
+
+        slots = torch.tensor(slots, dtype=torch.int64, device=device)
+        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -420,7 +463,8 @@ class FlashCausalLMBatch(Batch):
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
             slot_indices=None,
-            slots=None,
+            slots=slots,
+            cu_slots=cu_slots,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
@@ -457,10 +501,11 @@ class FlashCausalLMBatch(Batch):
         # Used to index into tensors
         indices = []

-        # slots to keep after filtering
-        slot_filtering_indices = torch.zeros(
-            self.slots.shape[0], dtype=torch.bool, device=device
-        )
+        if not has_triton():
+            # slots to keep after filtering
+            slot_filtering_indices = torch.zeros(
+                self.slots.shape[0], dtype=torch.bool, device=device
+            )

         # Create on CPU to only move to GPU once instead of at every copy
         slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
@@ -477,6 +522,7 @@ class FlashCausalLMBatch(Batch):
         cache_lengths = []
         prefix_offsets = []
         read_offsets = []
+        cu_slots = [0]

         prefilling_mask = []
         prefill_logprob_tokens = []
@@ -487,8 +533,8 @@ class FlashCausalLMBatch(Batch):

         num_blocks = 0
         max_blocks = 0
-        # Cumulative length
-        cumulative_max_length = 0
+        max_slots = 0
+        cumulative_slot_tokens = 0

         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
@@ -531,29 +577,27 @@ class FlashCausalLMBatch(Batch):
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)

+            start_slot = self.cu_slots[idx]
+            end_slot = self.cu_slots[idx + 1]
+            slot_length = end_slot - start_slot
+
+            if not has_triton():
+                # Set slice
+                slot_filtering_indices[start_slot:end_slot] = True
+
+            cu_slots.append(cumulative_slot_tokens + slot_length)
+
             # Input ids if the request was part of a prefilling batch
             # If the batch was decoding we can index into the tensor directly later
             if self.prefilling:
                 input_ids.append(self.input_ids[idx])
             else:
                 # Copy to tensor (CPU)
-                slot_indices[i] = cumulative_max_length
-
-                remaining_tokens = (
-                    stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-                )
-
-                # Set slice
-                slot_filtering_indices[
-                    self.slot_indices[idx] : self.slot_indices[idx]
-                    + request_input_length
-                    + remaining_tokens
-                    - 1
-                ] = True
-
-                cumulative_max_length += request_input_length + remaining_tokens - 1
+                slot_indices[i] = cumulative_slot_tokens + request_cache_length

+            cumulative_slot_tokens += slot_length
             max_blocks = max(max_blocks, len(request_block_table))
+            max_slots = max(max_slots, slot_length)

         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
@@ -564,11 +608,22 @@ class FlashCausalLMBatch(Batch):
         )
         prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

+        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
+
+        if not has_triton():
+            slots = self.slots[slot_filtering_indices]
+        else:
+            slots = self.slots.new_empty(cumulative_slot_tokens)
+            gpu_cu_slots = cu_slots.to(device)
+            slots_indexing_start = self.cu_slots.to(device)[indices]
+            slots_filtering(
+                max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
+            )
+
         if self.prefilling:
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
             slot_indices = None
-            slots = None
             cache_lengths_tensor = None
             input_lengths_tensor = None
             adapter_meta = None
@@ -578,7 +633,6 @@ class FlashCausalLMBatch(Batch):
             position_ids = self.position_ids[indices]
             adapter_indices = self.adapter_meta.adapter_indices[indices]
             input_lengths_tensor = self.input_lengths_tensor[indices]
-            slots = self.slots[slot_filtering_indices]
             cache_lengths_tensor = self.cache_lengths_tensor[indices]

             # Move to GPU now that we have the whole tensor
@@ -607,6 +661,7 @@ class FlashCausalLMBatch(Batch):
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
+            cu_slots=cu_slots,
             max_input_length=max_input_length,
             max_current_length=max_current_length,
             prefilling=self.prefilling,
@@ -653,9 +708,7 @@ class FlashCausalLMBatch(Batch):
         for b in batches:
             total_batch_size += len(b)
             max_blocks = max(max_blocks, b.max_blocks)
-            # If `b` is prefilling and was just filtered, `b.slots` is None
-            # `total_slots` is not used if any of the batches is prefilling
-            total_slots += len(b.slots) if not b.prefilling else 0
+            total_slots += len(b.slots)
             num_blocks += b.num_blocks
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
@@ -675,11 +728,12 @@ class FlashCausalLMBatch(Batch):
             )
             prefilling = prefilling or b.prefilling

+        slots = batches[0].slots.new_empty(total_slots)
+        cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
         if prefilling:
             input_ids = []
             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids = None
-            slots = None
             slot_indices = None
             cache_lengths_tensor = None
             input_lengths_tensor = None
@@ -688,7 +742,6 @@ class FlashCausalLMBatch(Batch):
         else:
             input_ids = batches[0].input_ids.new_empty(total_batch_size)
             position_ids = batches[0].position_ids.new_empty(total_batch_size)
-            slots = batches[0].slots.new_empty(total_slots)
             slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
             input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
                 total_batch_size
@@ -764,13 +817,16 @@ class FlashCausalLMBatch(Batch):
             ] = batch.block_tables_tensor[:, :max_blocks]
             prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

-            if not prefilling:
-                slots_start_index = cumulative_slots
-                slots_end_index = cumulative_slots + len(batch.slots)
+            slots_start_index = cumulative_slots
+            slots_end_index = cumulative_slots + len(batch.slots)
+            slots[slots_start_index:slots_end_index] = batch.slots
+            cu_slots[start_index + 1 : end_index + 1] = (
+                batch.cu_slots[1:] + cumulative_slots
+            )

+            if not prefilling:
                 input_ids[start_index:end_index] = batch.input_ids
                 position_ids[start_index:end_index] = batch.position_ids
-                slots[slots_start_index:slots_end_index] = batch.slots
                 slot_indices[start_index:end_index] = (
                     batch.slot_indices + cumulative_slots
                 )
@@ -792,9 +848,6 @@ class FlashCausalLMBatch(Batch):
                     batch.adapter_meta.adapter_segments,
                     batch.adapter_meta.segment_indices,
                 )
-
-                # Update
-                cumulative_slots += len(batch.slots)
             else:
                 if isinstance(batch.input_ids, torch.Tensor):
                     batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@@ -819,6 +872,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.extend(batch.top_n_tokens)

             # Update
+            cumulative_slots += len(batch.slots)
             cumulative_batch_size += len(batch)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -858,6 +912,7 @@ class FlashCausalLMBatch(Batch):
             cache_lengths=cache_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
             slots=slots,
+            cu_slots=cu_slots,
             max_input_length=max_input_length,
             max_current_length=max_current_length,
             prefilling=prefilling,
@@ -890,15 +945,50 @@ class FlashCausalLMBatch(Batch):
         # it simplifies everything
         assert self.speculative_ids is None

+        device = self.block_tables_tensor.device
+
+        if isinstance(self.input_ids, list):
+            if len(self) > 1:
+                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
+            else:
+                input_ids = self.input_ids[0]
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
+        self.input_lengths_tensor = torch.tensor(
+            self.input_lengths, dtype=torch.int32, device=device
+        )
+        self.cu_seqlen_prefill = torch.nn.functional.pad(
+            torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
+        ).to(torch.int32)
+        self.cache_lengths_tensor = torch.tensor(
+            self.cache_lengths, dtype=torch.int32, device=device
+        )
+
+        # If the device supports Triton, we can use a fused kernel
+        if has_triton():
+            self.position_ids = torch.empty(
+                len(self.input_ids), dtype=torch.int32, device=device
+            )
+            self.slot_indices = torch.empty(
+                len(self.input_ids), dtype=torch.int64, device=device
+            )
+            cu_slots_gpu = self.cu_slots.to(device)
+
+            prepare_position_slot_ids(
+                self.max_input_length,
+                self.cache_lengths_tensor,
+                self.cu_seqlen_prefill,
+                cu_slots_gpu,
+                self.position_ids,
+                self.slot_indices,
+            )
+
         sliding_window = get_sliding_windows()
         position_ids = []
-        cu_seqlen_prefill = [0]
         slot_indices = []
         prefill_cache_indices = []
         all_prefill_logprobs = True
         no_prefill_logprobs = True
-        prefill_head_indices = []
-        prefill_next_token_indices = []
         prefill_cu_outlens = [0]

         # Cumulative length
@@ -906,7 +996,6 @@ class FlashCausalLMBatch(Batch):
         cumulative_slot_tokens = 0
         prefill_out_cumulative_length = 0

-        slots = []
         adapter_indices_list = []
         adapter_set = set()

@@ -928,30 +1017,33 @@ class FlashCausalLMBatch(Batch):
             )
         ):
             next_chunk_length = input_length
-            # Position ids
-            request_position_ids = torch.arange(
-                cache_length, cache_length + input_length, dtype=torch.int32
-            )
-            position_ids.append(request_position_ids)

-            # Add cumulative lengths of all previous inputs
-            cu_seqlen_prefill.append(cumulative_length + input_length)
+            if not has_triton():
+                # Position ids
+                request_position_ids = torch.arange(
+                    cache_length, cache_length + input_length, dtype=torch.int32
+                )
+                position_ids.append(request_position_ids)

                 if not r.slots:
                     request_slots = [
                         s
                         for b in blocks
                         for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                     ]
                 else:
                     request_slots = r.slots

-            request_slots = request_slots[cache_length:]
-            request_slot_indices = torch.arange(
-                cumulative_slot_tokens,
-                cumulative_slot_tokens + input_length,
-                dtype=torch.int64,
-            )
+                request_slot_indices = torch.arange(
+                    cache_length + cumulative_slot_tokens,
+                    cache_length + cumulative_slot_tokens + input_length,
+                    dtype=torch.int64,
+                )
+
+                slot_indices.append(request_slot_indices)
+
+                # Update
+                cumulative_slot_tokens += len(request_slots)

             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
@@ -968,83 +1060,102 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

             if prefill_logprobs:
-                prefill_head_indices.append(
-                    torch.arange(
-                        cumulative_length,
-                        cumulative_length + input_length,
-                        dtype=torch.int64,
-                    )
-                )
-                prefill_next_token_indices.append(
-                    prefill_out_cumulative_length + input_length - 1
-                )
                 prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                 prefill_out_cumulative_length += input_length
             else:
-                prefill_head_indices.append(
-                    torch.tensor(
-                        [cumulative_length + input_length - 1],
-                        dtype=torch.int64,
-                    )
-                )
-                prefill_next_token_indices.append(prefill_out_cumulative_length)
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1

-            slots.extend(request_slots)
-            slot_indices.append(request_slot_indices)
-
             if sliding_window is not None:
                 prefill_cache_indices.append(request_prefill_cache_indices)

             ADAPTER_TO_INDEX = get_adapter_to_index()
-            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
-            adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
-            adapter_set.add(adapter_index)
+            if ADAPTER_TO_INDEX:
+                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+                adapter_indices_list.append(
+                    torch.full((next_chunk_length,), adapter_index)
+                )
+                adapter_set.add(adapter_index)

             # Update
             cumulative_length += next_chunk_length
-            cumulative_slot_tokens += len(request_slots)

-        device = self.block_tables_tensor.device
-        if isinstance(self.input_ids, list):
-            if len(self) > 1:
-                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
-            else:
-                input_ids = self.input_ids[0]
-            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+        if not all_prefill_logprobs and not no_prefill_logprobs:
+            prefill_head_indices = []
+            prefill_next_token_indices = []
+
+            # Cumulative length
+            cumulative_length = 0
+            prefill_out_cumulative_length = 0
+
+            for i, (
+                r,
+                input_length,
+                request_prefilling,
+            ) in enumerate(
+                zip(
+                    self.requests,
+                    self.input_lengths,
+                    self.prefilling_mask,
+                )
+            ):
+                # Prefill logprobs is ignored if the request is done prefilling
+                prefill_logprobs = r.prefill_logprobs and request_prefilling
+
+                if prefill_logprobs:
+                    prefill_head_indices.append(
+                        torch.arange(
+                            cumulative_length,
+                            cumulative_length + input_length,
+                            dtype=torch.int64,
+                        )
+                    )
+                    prefill_next_token_indices.append(
+                        prefill_out_cumulative_length + input_length - 1
+                    )
+                    prefill_out_cumulative_length += input_length
+                else:
+                    prefill_head_indices.append(
+                        torch.tensor(
+                            [cumulative_length + input_length - 1],
+                            dtype=torch.int64,
+                        )
+                    )
+                    prefill_next_token_indices.append(prefill_out_cumulative_length)
+                    prefill_out_cumulative_length += 1
+
+                # Update
+                cumulative_length += input_length

         if len(self) > 1:
-            position_ids = torch.cat(position_ids)
-            slot_indices = torch.cat(slot_indices)
+            if position_ids:
+                position_ids = torch.cat(position_ids)
+            if slot_indices:
+                slot_indices = torch.cat(slot_indices)
             if sliding_window is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
-            position_ids = position_ids[0]
-            slot_indices = slot_indices[0]
+            if position_ids:
+                position_ids = position_ids[0]
+            if slot_indices:
+                slot_indices = slot_indices[0]
             if sliding_window is not None:
                 prefill_cache_indices = prefill_cache_indices[0]

+        if not has_triton():
+            self.position_ids = position_ids.to(device)
+            self.slot_indices = slot_indices.to(device)
+
         self.prefill_cu_outlens = prefill_cu_outlens
-        cu_seqlen_prefill = torch.tensor(
-            cu_seqlen_prefill, device=device, dtype=torch.int32
-        )
-        self.cu_seqlen_prefill = cu_seqlen_prefill
-        self.position_ids = position_ids.to(device)
-        self.slot_indices = slot_indices.to(device)
         self.prefill_cache_indices = (
             prefill_cache_indices.to(device) if sliding_window is not None else None
         )
-        self.input_lengths_tensor = torch.tensor(
-            self.input_lengths, dtype=torch.int32, device=device
-        )

         if all_prefill_logprobs:
             prefill_head_indices = None
-            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
+            prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
         elif no_prefill_logprobs:
-            prefill_head_indices = cu_seqlen_prefill[1:] - 1
+            prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
             prefill_next_token_indices = None
         else:
             prefill_head_indices = torch.cat(prefill_head_indices).to(device)
@@ -1054,17 +1165,21 @@ class FlashCausalLMBatch(Batch):

         self.prefill_head_indices = prefill_head_indices
         self.prefill_next_token_indices = prefill_next_token_indices
-        self.slots = torch.tensor(slots, dtype=torch.int64, device=device)
-        self.cache_lengths_tensor = torch.tensor(
-            self.cache_lengths, dtype=torch.int32, device=device
-        )
-        adapter_indices = torch.cat(adapter_indices_list).to(
-            dtype=torch.int64, device=device
-        )
-        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+        if adapter_set:
+            adapter_indices = torch.cat(adapter_indices_list).to(
+                dtype=torch.int64, device=device
+            )
+            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+        else:
+            adapter_indices = torch.zeros_like(self.input_ids)
+            adapter_segments = [0, len(adapter_indices)]
+            adapter_segment_indices = [len(adapter_indices) - 1]

         adapter_segments = torch.tensor(
             adapter_segments, dtype=torch.int32, device=device
         )

         self.adapter_meta = AdapterBatchMetadata(
             adapter_indices=adapter_indices,
             adapter_set=adapter_set,
@@ -1288,6 +1403,9 @@ class FlashCausalLM(Model):
                 block_tables=block_tables,
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths,
+                input_lengths_tensor=input_lengths_tensor,
+                cache_lengths_tensor=cache_lengths_tensor,
+                max_current_length=max_s,
             )
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
@@ -1621,6 +1739,9 @@ class FlashCausalLM(Model):
                 block_tables=block_tables,
                 input_lengths=batch.input_lengths,
                 cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
             )
         with self._forward_context(
             block_tables=block_tables,
@@ -1661,6 +1782,9 @@ class FlashCausalLM(Model):
                 block_tables=block_tables,
                 input_lengths=batch.input_lengths,
                 cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
             )
         # assert block_tables.shape[0] >= slots.shape[0]
         cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
@@ -1756,7 +1880,6 @@ class FlashCausalLM(Model):
         else:
             prefill_logprobs = None
             next_token_logits = out
-            next_adapter_indices = batch.adapter_meta.adapter_indices

         finished_prefilling = True
         next_chunk_lengths = []
@@ -1827,13 +1950,12 @@ class FlashCausalLM(Model):
         # Since we are done prefilling, all the tensors that were concatenating values for all the requests
         # instantly become of shape [BATCH_SIZE]
         if prefill and finished_prefilling:
-            next_position_ids = batch.position_ids.new_empty(len(batch))
-            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
-            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
-                len(batch)
-            )
-        elif not prefill:
-            next_position_ids = batch.position_ids
+            indices = batch.cu_seqlen_prefill[1:] - 1
+            batch.position_ids = batch.position_ids[indices]
+            batch.slot_indices = batch.slot_indices[indices]
+            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
+                indices
+            ]

         # Zipped iterator
         iterator = zip(
@@ -1852,8 +1974,10 @@ class FlashCausalLM(Model):
         # It is faster if we delay this sync for the maximum amount of time

         # For each member of the batch
-        index = 0
         # Cumulative length
+        cu_accepted_ids = torch.nn.functional.pad(
+            torch.cumsum(accepted_ids, dim=0), (1, 0)
+        )
         cumulative_length = 0
         for i, (
             request,
@@ -1865,21 +1989,6 @@ class FlashCausalLM(Model):
             request_was_prefilling,
             request_is_prefilling,
         ) in enumerate(iterator):
-            if prefill and finished_prefilling:
-                # Indexing metadata
-                _start_index = cumulative_length
-                end_index = cumulative_length + input_length
-
-                # Initialize position_ids
-                # In decode, we do not need this as we can just increment position ids
-                next_position_ids[i] = batch.position_ids[end_index - 1]
-
-                # Initialize adapter indices
-                # In decode, we only have one token per row in the batch, so grab last index
-                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
-                    end_index - 1
-                ]
-
             # Used to gather prefill logprobs
             # Copy batch.all_input_ids_tensor to prefill_token_indices
             if request.prefill_logprobs and request_was_prefilling:
@@ -1898,25 +2007,39 @@ class FlashCausalLM(Model):
                 # Set prefill_tokens_indices to the correct slice
                 prefill_tokens_indices = ids

-            if not request_is_prefilling:
+            # If the device does not support triton, we copy one by one
+            if not request_is_prefilling and not has_triton():
                 # Only save tokens if we are done prefilling for this request
-                for j in range(n_accepted_ids):
-                    batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
-                        next_input_ids[index + j]
-                    )
-                index += n_accepted_ids
+                batch.all_input_ids_tensor[
+                    i,
+                    batch.cache_lengths_tensor[i]
+                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                    + batch.input_lengths[i]
+                    + accepted_ids[i],
+                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
             cumulative_length += input_length

+        # If the device support triton, we can use a fused kernel
+        if has_triton():
+            copy_next_input_ids_inplace(
+                speculate + 1,
+                batch.all_input_ids_tensor,
+                batch.cache_lengths_tensor,
+                batch.input_lengths_tensor,
+                batch.prompt_lengths_tensor,
+                next_input_ids,
+                cu_accepted_ids,
+            )
+
         # Update values
         # These values can be updated without a GPU -> CPU sync
         if not prefill or (prefill and finished_prefilling):
-            batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
             batch.speculative_ids = speculative_ids
-            batch.position_ids = next_position_ids + accepted_ids
-            batch.cache_lengths_tensor += batch.input_lengths_tensor
-            batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
+            batch.position_ids += accepted_ids
+            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
+            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
             batch.slot_indices += accepted_ids
-            batch.adapter_meta.adapter_indices = next_adapter_indices

         if prefill and prefill_logprobs:
             # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
@@ -2093,8 +2216,10 @@ class FlashCausalLM(Model):
                 # processing
                 stopped = False
                 new_input_length = next_chunk_lengths[i]
+                new_cache_length = cache_length + input_length
             else:
-                new_input_length = n_accepted_ids
+                new_input_length = 1
+                new_cache_length = cache_length + input_length + n_accepted_ids - 1
             # Append next token to all tokens
             next_token_texts = []
             left = 0
@@ -2206,12 +2331,10 @@ class FlashCausalLM(Model):

             # Update values
             index += n_accepted_ids
-            current_cache_length = cache_length + input_length
-            batch.cache_lengths[i] = current_cache_length
-            current_input_length = new_input_length
-            batch.max_input_length = max(batch.max_input_length, current_input_length)
-            batch.input_lengths[i] = current_input_length
-            current_length = current_cache_length + current_input_length
+            batch.cache_lengths[i] = new_cache_length
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
+            batch.input_lengths[i] = new_input_length
+            current_length = new_cache_length + new_input_length
             batch.max_current_length = max(batch.max_current_length, current_length)

             batch.prefix_offsets[i] = prefix_offset
@@ -2258,11 +2381,6 @@ class FlashCausalLM(Model):
                 state=(
                     state if state is not None else self.prefill_with_paged_kv_state
                 ),
-                # block_tables=block_tables_to_ragged(
-                #     block_tables=block_tables,
-                #     input_lengths=input_lengths,
-                #     cache_lengths=cache_lengths,
-                # ),
                 block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
                 input_lengths=input_lengths_tensor + cache_lengths_tensor,
@@ -2287,23 +2405,3 @@ class FlashCausalLM(Model):
             dtype=self.dtype,
             window_left=self.sliding_window,
         )
-
-
-def block_tables_to_ragged(
-    *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
-) -> torch.Tensor:
-    """Convert block table to ragged format compatible with FlashInfer."""
-    assert len(input_lengths) == len(cache_lengths)
-
-    total_len = sum(input_lengths) + sum(cache_lengths)
-    block_tables_ragged = torch.empty(
-        total_len, dtype=torch.int32, device=block_tables.device
-    )
-
-    offset = 0
-    for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
-        seq_len = cache_length + input_length
-        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
-        offset += seq_len
-
-    return block_tables_ragged
@@ -0,0 +1,347 @@
+import torch
+import triton
+
+import triton.language as tl
+
+from loguru import logger
+from typing import List, Optional
+from torch.utils._triton import has_triton as has_triton_torch
+
+from text_generation_server.utils.import_utils import (
+    SYSTEM,
+)
+from text_generation_server.utils.log import log_master
+
+_HAS_TRITON: Optional[bool] = None
+
+
+def has_triton():
+    global _HAS_TRITON
+    if _HAS_TRITON is None:
+        # FIXME: it seems that has_triton_torch is bugged on RocM
+        # For now, only accept cuda
+        _HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
+        if _HAS_TRITON:
+            log_master(logger.info, "Using optimized Triton indexing kernels.")
+
+    return _HAS_TRITON
+
+
+def block_tables_to_padded(
+    max_blocks: int,
+    cu_seqlen: torch.Tensor,
+    block_tables: torch.Tensor,
+    block_tables_ragged: torch.Tensor,
+):
+    def grid(meta):
+        return (
+            triton.cdiv(max_blocks, meta["BLOCK_SIZE"]),
+            len(block_tables),
+        )
+
+    triton_block_tables_to_padded[grid](
+        cu_seqlen,
+        block_tables,
+        block_tables_ragged,
+        block_tables.shape[1],
+        BLOCK_SIZE=256,
+    )
+
+
+def block_tables_to_ragged(
+    *,
+    block_tables: torch.Tensor,
+    input_lengths: List[int],
+    cache_lengths: List[int],
+    input_lengths_tensor: torch.Tensor,
+    cache_lengths_tensor: torch.Tensor,
+    max_current_length: int
+) -> torch.Tensor:
+    """Convert block table to ragged format compatible with FlashInfer."""
+    assert len(input_lengths) == len(cache_lengths)
+
+    total_len = sum(input_lengths) + sum(cache_lengths)
+    block_tables_ragged = torch.empty(
+        total_len, dtype=torch.int32, device=block_tables.device
+    )
+
+    if has_triton():
+        cu_seqlen = torch.nn.functional.pad(
+            torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
+        )
+
+        def grid(meta):
+            return (
+                triton.cdiv(max_current_length, meta["BLOCK_SIZE"]),
+                len(cache_lengths),
+            )
+
+        triton_block_tables_to_ragged[grid](
+            cu_seqlen,
+            block_tables,
+            block_tables_ragged,
+            block_tables.shape[1],
+            BLOCK_SIZE=256,
+        )
+    else:
+        offset = 0
+        for i, (input_length, cache_length) in enumerate(
+            zip(input_lengths, cache_lengths)
+        ):
+            seq_len = cache_length + input_length
+            block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
+            offset += seq_len
+
+    return block_tables_ragged
+
+
+def copy_next_input_ids_inplace(
+    max_next_input_ids: int,
+    all_input_ids: torch.Tensor,
+    cache_lengths: torch.Tensor,
+    input_lengths: torch.Tensor,
+    prompt_lengths: torch.Tensor,
+    next_input_ids: torch.Tensor,
+    cu_accepted_ids: torch.Tensor,
+):
+    def grid(meta):
+        return (
+            triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]),
+            len(all_input_ids),
+        )
+
+    triton_copy_next_input_ids_inplace[grid](
+        all_input_ids,
+        cache_lengths,
+        input_lengths,
+        prompt_lengths,
+        next_input_ids,
+        cu_accepted_ids,
+        all_input_ids.shape[1],
+        BLOCK_SIZE=16,
+    )
+
+
+def prepare_position_slot_ids(
+    max_input_length: int,
+    cache_lengths: torch.Tensor,
+    cu_seqlen: torch.Tensor,
+    cu_slots: torch.Tensor,
+    position_ids: torch.Tensor,
+    slot_indices: torch.Tensor,
+):
+    def grid(meta):
+        return (
+            triton.cdiv(max_input_length, meta["BLOCK_SIZE"]),
+            len(cache_lengths),
+        )
+
+    triton_prepare_position_slot_ids[grid](
+        cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
+    )
+
+
+def slots_filtering(
+    max_slots: int,
+    slots: torch.Tensor,
+    filtered_slots: torch.Tensor,
+    cu_slots: torch.Tensor,
+    slots_start: torch.Tensor,
+):
+    def grid(meta):
+        return (
+            triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
+            len(slots_start),
+        )
+
+    triton_slots_filtering[grid](
+        slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
+    )
+
+
+@triton.jit
+def triton_slots_filtering(
+    # Inputs
+    slots_ptr,
+    filtered_slots_ptr,
+    slots_start_ptr,
+    cu_slots_ptr,
+    # Const values
+    BLOCK_SIZE: "tl.constexpr",
+):
+    # Position in block_tables_ragged.numel() / BLOCK_SIZE
+    pid = tl.program_id(axis=0)
+    # Position in batch
+    bid = tl.program_id(axis=1)
+
+    block_start = pid * BLOCK_SIZE
+    block_arange = block_start + tl.arange(0, BLOCK_SIZE)
+
+    filter_start = tl.load(slots_start_ptr + bid)
+
+    slot_start = tl.load(cu_slots_ptr + bid)
+    slot_end = tl.load(cu_slots_ptr + bid + 1)
+
+    mask = (slot_start + block_arange) < slot_end
+
+    slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
+    tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)
+
+
+@triton.jit
+def triton_block_tables_to_padded(
+    # Inputs
+    cu_seqlen_ptr,
+    # Outputs
+    block_tables_ptr,
+    block_tables_ragged_ptr,
+    # Stride
+    stride_block_tables,
+    # Const values
+    BLOCK_SIZE: "tl.constexpr",
+):
+    # Position in block_tables_ragged.numel() / BLOCK_SIZE
+    pid = tl.program_id(axis=0)
+    # Position in batch
+    bid = tl.program_id(axis=1)
+
+    block_start = pid * BLOCK_SIZE
+    block_arange = block_start + tl.arange(0, BLOCK_SIZE)
+
+    seq_start = tl.load(cu_seqlen_ptr + bid)
+    seq_end = tl.load(cu_seqlen_ptr + bid + 1)
+
+    mask = (seq_start + block_arange) < seq_end
+
+    blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask)
+    tl.store(
+        block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask
+    )
+
+
+@triton.jit
+def triton_block_tables_to_ragged(
+    # Inputs
+    cu_seqlen_ptr,
+    # Outputs
+    block_tables_ptr,
+    block_tables_ragged_ptr,
+    # Stride
+    stride_block_tables,
+    # Const values
+    BLOCK_SIZE: "tl.constexpr",
+):
+    # Position in block_tables_ragged.numel() / BLOCK_SIZE
+    pid = tl.program_id(axis=0)
+    # Position in batch
+    bid = tl.program_id(axis=1)
+
+    block_start = pid * BLOCK_SIZE
+    block_arange = block_start + tl.arange(0, BLOCK_SIZE)
+
+    seq_start = tl.load(cu_seqlen_ptr + bid)
+    seq_end = tl.load(cu_seqlen_ptr + bid + 1)
+
+    mask = (seq_start + block_arange) < seq_end
+
+    blocks = tl.load(
+        block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask
+    )
+    tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask)
+
+
+@triton.jit
+def triton_copy_next_input_ids_inplace(
+    # Inputs
+    all_input_ids_ptr,
+    cache_lengths_ptr,
+    input_lengths_ptr,
+    prompt_lengths_ptr,
+    next_input_ids_ptr,
+    cu_accepted_ids_ptr,
+    # Stride
+    stride_all_input_ids,
+    # Const values
+    BLOCK_SIZE: "tl.constexpr",
+):
+    # Position in max_accepted_ids / BLOCK_SIZE
+    pid = tl.program_id(axis=0)
+    # Position in batch
+    bid = tl.program_id(axis=1)
+
+    block_start = pid * BLOCK_SIZE
+    block_arange = block_start + tl.arange(0, BLOCK_SIZE)
+
+    # Used for correctly indexing in all_input_ids
+    cache_length = tl.load(cache_lengths_ptr + bid)
+    input_length = tl.load(input_lengths_ptr + bid)
+    prompt_length = tl.load(prompt_lengths_ptr + bid)
+
+    # Start/End of next_input_ids for this request
+    next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid)
+    next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1)
+
+    # Mask values out of range
+    mask = (next_input_ids_start + block_arange) < next_input_ids_end
+
+    # Mask values for request still prefilling
+    decode_mask = (cache_length + input_length + block_arange) >= prompt_length
+
+    mask = mask & decode_mask
+
+    # Load this request next input ids
+    next_input_ids = tl.load(
+        next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask
+    )
+
+    # Store in all_input_ids, since it is a 2D tensor, apply stride * bid
+    tl.store(
+        all_input_ids_ptr
+        + stride_all_input_ids * bid
+        + cache_length
+        + input_length
+        + block_arange,
+        next_input_ids,
+        mask=mask,
+    )
+
+
+@triton.jit
+def triton_prepare_position_slot_ids(
+    # Inputs
+    cache_lengths_ptr,
+    cu_seqlen_ptr,
+    cu_slots_ptr,
+    # Outputs
+    position_ids_ptr,
+    slot_indices_ptr,
+    # Const values
+    BLOCK_SIZE: "tl.constexpr",
+):
+    # Position in max_input_length / BLOCK_SIZE
+    pid = tl.program_id(axis=0)
+    # Position in batch
+    bid = tl.program_id(axis=1)
+
+    block_start = pid * BLOCK_SIZE
+    block_arange = block_start + tl.arange(0, BLOCK_SIZE)
+
+    cache_length = tl.load(cache_lengths_ptr + bid)
+
+    seq_start = tl.load(cu_seqlen_ptr + bid)
+    seq_end = tl.load(cu_seqlen_ptr + bid + 1)
+
+    slot_start = tl.load(cu_slots_ptr + bid)
+
+    mask = (seq_start + block_arange) < seq_end
+
+    tl.store(
+        position_ids_ptr + seq_start + block_arange,
+        cache_length + block_arange,
+        mask=mask,
+    )
+    tl.store(
+        slot_indices_ptr + seq_start + block_arange,
+        slot_start + cache_length + block_arange,
+        mask=mask,
+    )
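All of the wrappers above share one launch convention: a 2D grid whose axis 0 walks elements in chunks of BLOCK_SIZE and whose axis 1 walks the batch. As a sketch of how slots_filtering is driven, mirroring the call in FlashCausalLMBatch.filter; the tensors below are made up for illustration and a CUDA device is assumed:

import torch

from text_generation_server.models.metadata_kernels import slots_filtering

# Keep requests 0 and 2 of an original batch of three, 16 slots each kept.
slots = torch.arange(48, dtype=torch.int64, device="cuda")              # old concatenated slots
cu_slots = torch.tensor([0, 16, 32], dtype=torch.int64, device="cuda")  # cumulative layout after filtering
slots_start = torch.tensor([0, 32], dtype=torch.int64, device="cuda")   # where each kept request started
max_slots = 16

filtered = slots.new_empty(32)
slots_filtering(max_slots, slots, filtered, cu_slots, slots_start)
# filtered now holds slots[0:16] followed by slots[32:48].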
@@ -14,11 +14,9 @@ from transformers import (

 from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.flash_causal_lm import (
-    block_tables_to_ragged,
-)
 from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.layers.attention import Seqlen
+from text_generation_server.models.metadata_kernels import block_tables_to_ragged


 tracer = trace.get_tracer(__name__)
@@ -283,6 +281,9 @@ class MllamaCausalLM(VlmCausalLM):
                 block_tables=block_tables,
                 input_lengths=batch.input_lengths,
                 cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
             )
         with self._forward_context(
             block_tables=block_tables,
@@ -338,6 +339,9 @@ class MllamaCausalLM(VlmCausalLM):
                 block_tables=block_tables,
                 input_lengths=batch.input_lengths,
                 cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
             )
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
         else:
@@ -11,12 +11,12 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
-    block_tables_to_ragged,
 )
 from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
+from text_generation_server.models.metadata_kernels import block_tables_to_ragged

 tracer = trace.get_tracer(__name__)

@@ -363,6 +363,9 @@ class VlmCausalLM(FlashCausalLM):
                 block_tables=block_tables,
                 input_lengths=batch.input_lengths,
                 cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
             )
         with self._forward_context(
             block_tables=block_tables,
@@ -411,6 +414,9 @@ class VlmCausalLM(FlashCausalLM):
                 block_tables=block_tables,
                 input_lengths=batch.input_lengths,
                 cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
             )
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
         else: