fix: improve find_segments via numpy diff (#2686)

This commit is contained in:
drbh 2024-11-18 09:51:06 -05:00 committed by GitHub
parent 52e48739a5
commit fea62e928f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 11 additions and 16 deletions

View File

@@ -5,30 +5,25 @@
from typing import List, Tuple, Union
import torch
import numpy as np
def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    """Split a sequence of adapter indices into runs of equal consecutive values.

    Args:
        adapter_indices: One-dimensional sequence of adapter ids, given either
            as a ``torch.Tensor`` or as a list of ints.

    Returns:
        A ``(segments, segment_indices)`` tuple of plain Python lists.
        ``segments`` holds the run boundaries: it starts with ``0`` and ends
        with ``len(adapter_indices)``, so run ``k`` covers positions
        ``segments[k]`` up to (not including) ``segments[k + 1]``.
        ``segment_indices[k]`` is the adapter id shared by run ``k``.
        Empty input yields ``([0], [])``.
    """
    if isinstance(adapter_indices, torch.Tensor):
        # One bulk copy to host memory; calling .item() element by element on
        # a CUDA tensor is very slow.
        adapter_indices = adapter_indices.cpu().numpy()
    else:
        adapter_indices = np.asarray(adapter_indices)

    total = len(adapter_indices)
    if total == 0:
        # Nothing to index into: keep the single leading boundary and report
        # no segments instead of raising IndexError on adapter_indices[0].
        return [0], []

    # Position i (i >= 1) opens a new run when its value differs from its
    # predecessor. Comparing neighbours directly, rather than subtracting
    # them, also works for bool and unsigned dtypes.
    run_starts = np.flatnonzero(adapter_indices[1:] != adapter_indices[:-1]) + 1

    segments = [0, *run_starts.tolist(), total]
    # Every boundary except the final one is the first position of a run.
    segment_indices = adapter_indices[segments[:-1]].tolist()
    return segments, segment_indices