from typing import List, Tuple, Union

import numpy as np
import torch


def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    """Split a sequence of adapter indices into runs of equal values.

    Args:
        adapter_indices: 1-D sequence of adapter ids, either a tensor or a
            plain Python sequence.

    Returns:
        A ``(segments, segment_indices)`` pair. ``segments`` holds the run
        boundaries: it starts with ``0``, ends with ``len(adapter_indices)``
        and contains every position where the value changes in between.
        ``segment_indices`` holds the adapter id of each run, so
        ``len(segments) == len(segment_indices) + 1``. Empty input yields
        ``([0], [])``.
    """
    if isinstance(adapter_indices, torch.Tensor):
        # One bulk device-to-host copy; per-element access on a CUDA tensor
        # would synchronize on every read.
        values = adapter_indices.cpu().numpy()
    else:
        # asarray also accepts tuples and existing ndarrays, not only lists.
        values = np.asarray(adapter_indices)

    # Nothing to segment: keep the historical result instead of indexing
    # into an empty array.
    if values.size == 0:
        return [0], []

    # Position i (i >= 1) starts a new run when it differs from position
    # i - 1. Comparing neighbours avoids arithmetic on the ids, so it cannot
    # wrap around for unsigned dtypes.
    boundaries = np.flatnonzero(values[1:] != values[:-1]) + 1

    segments = [0]
    segments.extend(boundaries.tolist())
    segments.append(len(values))

    # Every run is identified by the value at its first position.
    segment_indices = [values[0].item()]
    segment_indices.extend(values[boundaries].tolist())

    return segments, segment_indices