2023-06-01 03:41:35 -06:00
|
|
|
|
import inspect
|
2022-11-04 07:22:47 -06:00
|
|
|
|
import torch
|
|
|
|
|
|
2022-11-03 09:07:54 -06:00
|
|
|
|
from abc import ABC, abstractmethod
|
2024-06-25 12:46:27 -06:00
|
|
|
|
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
|
|
|
|
|
from collections import defaultdict
|
2023-07-21 02:59:00 -06:00
|
|
|
|
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
2022-10-28 11:24:00 -06:00
|
|
|
|
|
2023-07-31 06:35:14 -06:00
|
|
|
|
from text_generation_server.models.types import Batch, Generation
|
2023-12-11 04:46:30 -07:00
|
|
|
|
from text_generation_server.utils.speculate import get_speculate
|
2023-04-21 07:36:29 -06:00
|
|
|
|
from text_generation_server.pb.generate_pb2 import InfoResponse
|
2024-06-25 12:46:27 -06:00
|
|
|
|
from text_generation_server.adapters.weights import LayerAdapterWeights
|
|
|
|
|
from text_generation_server.utils.adapter import (
|
|
|
|
|
load_and_merge_adapters,
|
|
|
|
|
AdapterParameters,
|
|
|
|
|
AdapterSource,
|
|
|
|
|
)
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default value of `Model.__init__`'s `adapter_id`: names the base model
# itself rather than any adapter.
BASE_MODEL_ADAPTER_ID: str = "__base_model__"

# Type variable for the concrete `Batch` subclass a model works with, so that
# subclasses can narrow the batch type of `batch_type` / `generate_token` /
# `warmup`.
B = TypeVar("B", bound=Batch)
|
|
|
|
|
|
2023-07-24 03:43:58 -06:00
|
|
|
|
|
2022-11-03 09:07:54 -06:00
|
|
|
|
class Model(ABC):
    """Abstract base class for models in ``text_generation_server``.

    Holds a ``torch.nn.Module`` together with its tokenizer and runtime
    settings (dtype, device, rank/world size, sliding window, speculation),
    and provides shared helpers: incremental detokenization for streaming,
    a check for parameters left on the ``meta`` device, and adapter
    loading/offloading. Subclasses must implement ``batch_type`` and
    ``generate_token``; adapter-capable subclasses override
    ``supports_adapter_loading`` and the other adapter hooks.
    """

    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
    ):
        """Store the model, tokenizer and runtime configuration.

        Args:
            model_id: Identifier of the base model; forwarded to
                ``load_and_merge_adapters`` when adapters are loaded.
            model: The wrapped module; it is switched to eval mode.
            tokenizer: Tokenizer used by ``decode_token``.
            requires_padding: Reported through ``info``.
            dtype: Reported through ``info``.
            device: Its ``type`` is reported through ``info``.
            rank: Stored as ``self.rank`` (presumably this shard's rank).
            world_size: Stored as ``self.world_size`` (presumably the number
                of shards).
            sliding_window: Attention window size reported through ``info``;
                ``-1`` is normalized to ``None``.
            speculate: Speculation setting reported through ``info``; when
                ``None``, the value of ``get_speculate()`` is used.
            adapter_id: Adapter the model was initialized with, stored as
                ``static_adapter_id``.

        Raises:
            RuntimeError: If any model parameter is still on the ``meta``
                device (see ``check_initialized``).
        """
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is
        # unpacked, so also collect every added token flagged as special.
        # TODO report this to transformers.
        other_special_ids = {
            token_id
            for token_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)

        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        # -1 is accepted as an alias for "no sliding window".
        self.sliding_window = sliding_window if sliding_window != -1 else None

        # Per-layer adapter weights; an entry is created on first access.
        self.layer_to_adapter_weights: DefaultDict[str, LayerAdapterWeights] = (
            defaultdict(LayerAdapterWeights)
        )
        self.target_to_layer = self.adapter_target_to_layer()
        # Indices of the adapters currently registered via `load_adapter`.
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

        # True when the wrapped model's forward() accepts `position_ids`.
        self.has_position_ids = (
            "position_ids" in inspect.signature(model.forward).parameters
        )

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        """Describe this model's runtime configuration as an ``InfoResponse``.

        Raises:
            NotImplementedError: If the model requires padding and a sliding
                window is configured; that combination is not supported.
        """
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        """The batch class this model consumes and produces."""
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        """Run one generation step on ``batch``.

        Returns:
            A tuple of the generations produced by this step, an optional
            batch (presumably the batch to continue with, ``None`` once it is
            finished -- confirm against implementations), and a pair of ints
            whose meaning is defined by implementations.
        """
        raise NotImplementedError

    def warmup(self, batch: B) -> Optional[int]:
        """Run a single ``generate_token`` step on ``batch`` as a warmup.

        The base implementation discards the step's output and returns
        ``None``; subclasses may override it to return an int.
        """
        self.generate_token(batch)
        return None

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Incrementally detokenize ``all_input_ids`` for streaming output.

        Hack to hopefully support generate_stream for the maximum number of
        tokenizers: rather than decoding the newest token alone, the window
        starting at ``prefix_offset`` is decoded twice -- once up to
        ``read_offset`` (text already returned) and once to the end -- and
        only the difference is returned.

        Args:
            all_input_ids: Every token id of the sequence so far.
            prefix_offset: Start of the decoding window.
            read_offset: End of the tokens whose text was already returned.
            skip_special_tokens: Forwarded to ``tokenizer.decode``.

        Returns:
            ``(new_text, prefix_offset, read_offset)``. When new text is
            available the offsets advance to ``(read_offset,
            len(all_input_ids))``; otherwise ``""`` is returned with the
            offsets unchanged, so the same tokens are decoded again on the
            next call.
        """
        # The prefix text is necessary only to defeat cleanup algorithms in
        # the decode which decide to add a space or not depending on the
        # surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        # A trailing U+FFFD REPLACEMENT CHARACTER means the last token(s) are
        # a potentially unfinished byte sequence from byte-fallback
        # tokenization, so wait for more tokens. If it's in the middle, it's
        # probably a real invalid id generated by the model and is emitted.
        # The escape keeps the check independent of the file's encoding.
        if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        return "", prefix_offset, read_offset

    def check_initialized(self):
        """Ensure no parameter of ``self.model`` is still on the ``meta`` device.

        Raises:
            RuntimeError: Listing the names of all uninitialized parameters.
        """
        meta = torch.device("meta")
        uninitialized_parameters = [
            name
            for name, param in self.model.named_parameters()
            if param.data.device == meta
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )

    @property
    def supports_adapter_loading(self) -> bool:
        """Whether ``load_adapter`` / ``offload_adapter`` may be used (base: False)."""
        return False

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        """Map adapter targets to layers; called once from ``__init__``.

        ``load_adapter`` passes the first element of every value to
        ``load_and_merge_adapters`` as a weight name. The base implementation
        returns an empty dict.
        """
        return {}

    @property
    def adapter_layers(self) -> List[str]:
        """Names of the layers that can receive adapter weights (base: none)."""
        return []

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        """Adapter-layer hook for subclasses; the base implementation returns []."""
        return []

    def get_num_layers_for_type(self, layer_type: str) -> int:
        """Adapter-layer hook for subclasses; the base implementation returns 0."""
        return 0

    def is_row_parallel(self, layer_type: str) -> bool:
        """Adapter-layer hook for subclasses; the base implementation returns False."""
        return False

    @property
    def max_speculative_tokens(self) -> int:
        """Largest ``max_speculative_tokens`` across all per-layer adapter weights (0 if none)."""
        return max(
            (
                weights.max_speculative_tokens
                for weights in self.layer_to_adapter_weights.values()
            ),
            default=0,
        )

    def _ensure_dynamic_adapter_loading(self) -> None:
        """Raise ``ValueError`` unless dynamic adapter (un)loading is allowed.

        A subclass or instance may decide this by defining
        ``dynamic_adapter_loading_enabled``. ``Model`` itself never sets that
        attribute, so in its absence dynamic loading is allowed only for a
        model initialized from the base model (``BASE_MODEL_ADAPTER_ID``),
        which is the condition the error message below describes.
        """
        enabled = getattr(
            self,
            "dynamic_adapter_loading_enabled",
            self.static_adapter_id == BASE_MODEL_ADAPTER_ID,
        )
        if not enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                "and therefore does not support dynamic adapter loading. "
                "Please initialize a new model instance from the base model in "
                "order to use the dynamic adapter loading feature."
            )

    def load_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
        api_token: str,
        dynamic: bool = True,
    ):
        """Load adapter weights and register them under ``adapter_index``.

        The adapter files are resolved and merged by
        ``load_and_merge_adapters``. Then, for every layer in
        ``adapter_layers``, the adapter config builds that layer's weights,
        which are added to ``layer_to_adapter_weights``. Adapter weight
        names that no layer consumed are logged as a warning.

        Args:
            adapter_parameters: The adapter(s) to load and how to merge them.
            adapter_source: Where to load the adapter from.
            adapter_index: Key under which the adapter is registered; loading
                an index that is already loaded is a no-op.
            api_token: Forwarded to ``load_and_merge_adapters``.
            dynamic: Whether this is a dynamic load; dynamic loads are
                rejected when dynamic adapter loading is disabled. Also
                forwarded to ``load_batched_adapter_weights``.

        Raises:
            ValueError: If the model does not support adapter loading, or if
                ``dynamic`` is True and dynamic adapter loading is disabled.
        """
        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        if dynamic:
            self._ensure_dynamic_adapter_loading()

        logger.info(
            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
        )
        weight_names = tuple(v[0] for v in self.target_to_layer.values())
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            self.model_id,
            adapter_parameters,
            adapter_source,
            adapter_index,
            weight_names,
            api_token,
            False,
        )

        # Handed to every `load_batched_adapter_weights` call, which
        # presumably removes the names it consumes (NOTE(review): confirm);
        # whatever remains afterwards is reported as unused.
        unused_weight_names = adapter_weight_names.copy()
        for layer_name in self.adapter_layers:
            adapter_weights = adapter_config.load_batched_adapter_weights(
                self,
                module_map,
                layer_name,
                unused_weight_names,
                dynamic,
            )
            if adapter_weights is None:
                continue

            layer_weights = self.layer_to_adapter_weights[layer_name]
            layer_weights.add_adapter(adapter_index, adapter_weights)

        if unused_weight_names:
            logger.warning(
                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            # `Model` never defines a `tokenizers` registry. Use one only if a
            # subclass provides it, rather than failing with AttributeError
            # after the layer weights above were already registered.
            tokenizers = getattr(self, "tokenizers", None)
            if tokenizers is not None:
                tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
            else:
                logger.warning(
                    f"{','.join(adapter_parameters.adapter_ids)} provides its own "
                    "tokenizer, but this model has no per-adapter tokenizer "
                    "registry; ignoring the adapter tokenizer."
                )

        self.loaded_adapters.add(adapter_index)

    def offload_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
    ):
        """Remove the adapter registered under ``adapter_index`` from every adapter layer.

        Offloading an index that is not loaded is a no-op.
        ``adapter_parameters`` and ``adapter_source`` are not used by this
        implementation.

        Raises:
            ValueError: If the model does not support adapter loading, or if
                dynamic adapter loading is disabled.
        """
        if adapter_index not in self.loaded_adapters:
            # Adapter already offloaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        self._ensure_dynamic_adapter_loading()

        for layer_name in self.adapter_layers:
            # Membership test first so the defaultdict does not create empty
            # entries for layers that never received this adapter.
            if layer_name in self.layer_to_adapter_weights:
                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)

        self.loaded_adapters.remove(adapter_index)