feat: baseline impl single request multi lora support

This commit is contained in:
drbh 2024-06-04 20:07:28 +00:00
parent a046c303f7
commit c661631225
8 changed files with 168 additions and 70 deletions

View File

@ -121,6 +121,7 @@ def download_weights(
logger_level: str = "INFO",
json_output: bool = False,
trust_remote_code: bool = False,
merge_lora: bool = False,
):
# Remove default handler
logger.remove()
@ -151,18 +152,25 @@ def download_weights(
) is not None
if not is_local_model:
try:
adapter_config_filename = hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
# TODO: maybe reverse the default value of merge_lora?
# currently by default we don't merge the weights with the base model
if merge_lora:
try:
adapter_config_filename = hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
else:
utils.peft.download_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
try:
import json

View File

@ -92,7 +92,8 @@ class FlashLlamaAttention(torch.nn.Module):
prefix: str,
config,
weights,
all_adapter_weights,
lora_weights,
lora_configs,
):
super().__init__()
self.num_heads = config.num_attention_heads
@ -126,36 +127,24 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
self.index = index
self.adapter_weights = {}
adapter_names = list(all_adapter_weights.keys())
adapter_names = list(lora_weights.keys())
self.lora_a_matrix = torch.empty(
(len(adapter_names), 2, 4096, 8),
device=weights.device,
dtype=weights.dtype,
)
self.lora_b_matrix = torch.empty(
(len(adapter_names), 2, 8, 4096),
device=weights.device,
dtype=weights.dtype,
)
self.pre_multiplied_lora_matrix = torch.empty(
(len(adapter_names), 2, 4096, 4096),
self.n_loras = len(adapter_names)
self.pre_multiplied_lora_matrices = torch.empty(
(self.n_loras, 2, self.hidden_size, self.hidden_size),
device=weights.device,
dtype=weights.dtype,
)
self.key_to_index = {}
self.index_to_key = {}
lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
for adapter_index, adapter_name in enumerate(adapter_names):
self.lora_alpha = 16.0
self.lora_r = 8.0
self.lora_alpha = lora_configs[adapter_name].lora_alpha
self.lora_r = lora_configs[adapter_name].r
self.lora_scale = self.lora_alpha / self.lora_r
self.key_to_index[adapter_name] = adapter_index
self.index_to_key[adapter_index] = adapter_name
adapter_weights = all_adapter_weights[adapter_name]
adapter_weights = lora_weights[adapter_name]
for target_index, target in enumerate(["q", "v"]):
adapter_weight_a = adapter_weights.get_tensor(
f"{lora_prefix}.{target}_proj.lora_A.weight"
@ -168,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
adapter_weight_b.T,
).contiguous()
self.pre_multiplied_lora_matrix[adapter_index, target_index, :, :] = (
self.pre_multiplied_lora_matrices[adapter_index, target_index, :, :] = (
pre_multiplied_lora_matrix
)
@ -209,16 +198,26 @@ class FlashLlamaAttention(torch.nn.Module):
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
batch_size = query.size(0)
query_adapted = (
torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 0])
.squeeze(0)
.view(batch_size, self.num_heads, self.head_size)
)
value_adapted = (
torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 1])
.squeeze(0)
.view(batch_size, self.num_key_value_heads, self.head_size)
# hidden states without LoRA
hs_wl = hidden_states[lora_indices == -1]
adapted_query_states = [hs_wl]
adapted_value_states = [hs_wl]
for ind in range(self.n_loras):
mask = lora_indices == ind
hs_sub = hidden_states[mask]
mat_q = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 0])
mat_v = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 1])
adapted_query_states.append(mat_q)
adapted_value_states.append(mat_v)
query_adapted = torch.cat(adapted_query_states, dim=0).view(
batch_size, self.num_heads, self.head_size
)
value_adapted = torch.cat(adapted_value_states, dim=0).view(
batch_size, self.num_key_value_heads, self.head_size
)
query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
@ -328,14 +327,15 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights, all_adapter_weights):
def __init__(self, index, prefix, config, weights, lora_weights, lora_configs):
super().__init__()
self.self_attn = FlashLlamaAttention(
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
all_adapter_weights=all_adapter_weights,
lora_weights=lora_weights,
lora_configs=lora_configs,
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@ -391,7 +391,7 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, prefix, config, weights, all_adapter_weights):
def __init__(self, prefix, config, weights, lora_weights, lora_configs):
super().__init__()
process_group = weights.process_group
@ -408,7 +408,8 @@ class FlashLlamaModel(torch.nn.Module):
),
config=config,
weights=weights,
all_adapter_weights=all_adapter_weights,
lora_weights=lora_weights,
lora_configs=lora_configs,
)
for layer_id in range(config.num_hidden_layers)
]
@ -471,7 +472,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, all_adapter_weights):
def __init__(self, prefix, config, weights, lora_weights, lora_configs):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@ -480,7 +481,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
),
weights=weights,
)
self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights)
self.model = FlashLlamaModel(
prefix, config, weights, lora_weights, lora_configs
)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:

View File

@ -1069,7 +1069,8 @@ class FlashCausalLM(Model):
for i, r in enumerate(batch.requests):
if r.adapter_id:
lora_index = self.model.get_lora_index(r.adapter_id)
lora_indices[i] = lora_index
input_length = batch.input_lengths[i]
lora_indices[i : i + input_length] = lora_index
batch_lora_adapter_mask[i] = True
if cu_seqlen_prefill is not None or cuda_graph is None:

View File

@ -16,11 +16,11 @@ from text_generation_server.utils import (
Weights,
hub,
)
from text_generation_server.utils.weights import load_adaptor_weights
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.lora import LoraConfig
class FlashLlama(FlashCausalLM):
@ -75,7 +75,9 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights)
model = FlashLlamaForCausalLM(
prefix, config, weights, lora_weights, lora_configs
)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model=model,

View File

@ -86,6 +86,18 @@ def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
return filenames
def _adapter_config_files_from_dir(d: Path) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(".json") and "arguments" not in f and "args" not in f
]
return filenames
def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]:

View File

@ -0,0 +1,72 @@
import json
from text_generation_server.utils import (
hub,
)
import os
class LoraConfig:
def __init__(
self,
alpha_pattern=None,
auto_mapping=None,
base_model_name_or_path="",
bias="none",
fan_in_fan_out=False,
inference_mode=True,
init_lora_weights=True,
layer_replication=None,
layers_pattern=None,
layers_to_transform=None,
loftq_config=None,
lora_alpha=16,
lora_dropout=0.1,
megatron_config=None,
megatron_core="megatron.core",
modules_to_save=None,
peft_type="LORA",
r=8,
rank_pattern=None,
revision=None,
target_modules=None,
task_type="CAUSAL_LM",
use_dora=False,
use_rslora=False,
):
self.alpha_pattern = alpha_pattern or {}
self.auto_mapping = auto_mapping
self.base_model_name_or_path = base_model_name_or_path
self.bias = bias
self.fan_in_fan_out = fan_in_fan_out
self.inference_mode = inference_mode
self.init_lora_weights = init_lora_weights
self.layer_replication = layer_replication
self.layers_pattern = layers_pattern
self.layers_to_transform = layers_to_transform
self.loftq_config = loftq_config or {}
self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout
self.megatron_config = megatron_config
self.megatron_core = megatron_core
self.modules_to_save = modules_to_save
self.peft_type = peft_type
self.r = r
self.rank_pattern = rank_pattern or {}
self.revision = revision
self.target_modules = target_modules or ["q_proj", "v_proj"]
self.task_type = task_type
self.use_dora = use_dora
self.use_rslora = use_rslora
@classmethod
def from_file(cls, filename):
with open(filename, "r") as f:
json_data = json.load(f)
return cls(**json_data)
# TODO: support fetching the model from the hub if it's not in the cache
@classmethod
def from_pretrained(cls, adapter_id, revision=None):
d = hub._get_cached_revision_directory(adapter_id, revision)
filename = os.path.join(d, "adapter_config.json")
return cls.from_file(filename)

View File

@ -43,3 +43,24 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir)
def download_peft(model_id, revision, trust_remote_code):
torch_dtype = torch.float16
try:
_model = AutoPeftModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
except Exception:
_model = AutoPeftModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
logger.info("Peft model downloaded.")

View File

@ -10,27 +10,6 @@ import json
from text_generation_server.utils.log import log_once
# TODO: improve how the weights are loaded
def load_adaptor_weights(model_id, local_path, extension=".safetensors"):
adapter_weights = {}
if local_path.exists() and local_path.is_dir():
local_files = list(local_path.glob(f"*{extension}"))
if not local_files:
raise FileNotFoundError(
f"No local weights found in {model_id} with extension {extension}"
)
for filename in local_files:
adapter_weights.update(load_file(filename))
# TODO: remove (no need to sort)
# sorted on the the layer number (index 4 in the key)
sorted_keys = sorted(
adapter_weights.keys(),
key=lambda x: int(x.split(".")[4]),
)
return (adapter_weights, sorted_keys)
class Weights:
def __init__(
self,