fix: improve find_segments via numpy diff (#2686)
This commit is contained in:
parent
52e48739a5
commit
fea62e928f
|
@ -5,30 +5,25 @@
|
|||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
# FIXME: this should be optimized
|
||||
def find_segments(
|
||||
adapter_indices: Union[torch.Tensor, List[int]]
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
segments = [0]
|
||||
segment_indices = []
|
||||
|
||||
if isinstance(adapter_indices, torch.Tensor):
|
||||
# Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
|
||||
adapter_indices = adapter_indices.cpu().tolist()
|
||||
adapter_indices = adapter_indices.cpu().numpy()
|
||||
elif isinstance(adapter_indices, list):
|
||||
adapter_indices = np.array(adapter_indices)
|
||||
|
||||
start_index = 0
|
||||
for i in range(1, len(adapter_indices)):
|
||||
if adapter_indices[i] != adapter_indices[i - 1]:
|
||||
segments.append(i)
|
||||
segment_indices.append(adapter_indices[i - 1])
|
||||
start_index = i
|
||||
change_mask = np.diff(adapter_indices, prepend=adapter_indices[0] - 1)
|
||||
change_indices = np.nonzero(change_mask)[0]
|
||||
|
||||
# Handle the last segment
|
||||
if start_index < len(adapter_indices):
|
||||
segments = [0]
|
||||
segments.extend(change_indices[1:].tolist())
|
||||
segments.append(len(adapter_indices))
|
||||
segment_indices.append(adapter_indices[-1])
|
||||
|
||||
segment_indices = adapter_indices[change_indices].tolist()
|
||||
|
||||
return segments, segment_indices
|
||||
|
||||
|
|
Loading…
Reference in New Issue