feat: baseline impl single request multi lora support
This commit is contained in:
parent
a046c303f7
commit
c661631225
|
@ -121,6 +121,7 @@ def download_weights(
|
|||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
merge_lora: bool = False,
|
||||
):
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
|
@ -151,6 +152,9 @@ def download_weights(
|
|||
) is not None
|
||||
|
||||
if not is_local_model:
|
||||
# TODO: maybe reverse the default value of merge_lora?
|
||||
# currently by default we don't merge the weights with the base model
|
||||
if merge_lora:
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(
|
||||
model_id, revision=revision, filename="adapter_config.json"
|
||||
|
@ -163,6 +167,10 @@ def download_weights(
|
|||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
else:
|
||||
utils.peft.download_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
try:
|
||||
import json
|
||||
|
|
|
@ -92,7 +92,8 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
all_adapter_weights,
|
||||
lora_weights,
|
||||
lora_configs,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
@ -126,36 +127,24 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.index = index
|
||||
self.adapter_weights = {}
|
||||
adapter_names = list(all_adapter_weights.keys())
|
||||
adapter_names = list(lora_weights.keys())
|
||||
|
||||
self.lora_a_matrix = torch.empty(
|
||||
(len(adapter_names), 2, 4096, 8),
|
||||
device=weights.device,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
self.lora_b_matrix = torch.empty(
|
||||
(len(adapter_names), 2, 8, 4096),
|
||||
device=weights.device,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
|
||||
self.pre_multiplied_lora_matrix = torch.empty(
|
||||
(len(adapter_names), 2, 4096, 4096),
|
||||
self.n_loras = len(adapter_names)
|
||||
self.pre_multiplied_lora_matrices = torch.empty(
|
||||
(self.n_loras, 2, self.hidden_size, self.hidden_size),
|
||||
device=weights.device,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
|
||||
self.key_to_index = {}
|
||||
self.index_to_key = {}
|
||||
|
||||
lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
|
||||
for adapter_index, adapter_name in enumerate(adapter_names):
|
||||
self.lora_alpha = 16.0
|
||||
self.lora_r = 8.0
|
||||
self.lora_alpha = lora_configs[adapter_name].lora_alpha
|
||||
self.lora_r = lora_configs[adapter_name].r
|
||||
self.lora_scale = self.lora_alpha / self.lora_r
|
||||
self.key_to_index[adapter_name] = adapter_index
|
||||
self.index_to_key[adapter_index] = adapter_name
|
||||
adapter_weights = all_adapter_weights[adapter_name]
|
||||
adapter_weights = lora_weights[adapter_name]
|
||||
for target_index, target in enumerate(["q", "v"]):
|
||||
adapter_weight_a = adapter_weights.get_tensor(
|
||||
f"{lora_prefix}.{target}_proj.lora_A.weight"
|
||||
|
@ -168,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
adapter_weight_b.T,
|
||||
).contiguous()
|
||||
|
||||
self.pre_multiplied_lora_matrix[adapter_index, target_index, :, :] = (
|
||||
self.pre_multiplied_lora_matrices[adapter_index, target_index, :, :] = (
|
||||
pre_multiplied_lora_matrix
|
||||
)
|
||||
|
||||
|
@ -209,16 +198,26 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
batch_size = query.size(0)
|
||||
query_adapted = (
|
||||
torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 0])
|
||||
.squeeze(0)
|
||||
.view(batch_size, self.num_heads, self.head_size)
|
||||
)
|
||||
|
||||
value_adapted = (
|
||||
torch.bmm(hidden_states.unsqueeze(0), self.pre_multiplied_lora_matrix[:, 1])
|
||||
.squeeze(0)
|
||||
.view(batch_size, self.num_key_value_heads, self.head_size)
|
||||
# hidden states without LoRA
|
||||
hs_wl = hidden_states[lora_indices == -1]
|
||||
|
||||
adapted_query_states = [hs_wl]
|
||||
adapted_value_states = [hs_wl]
|
||||
|
||||
for ind in range(self.n_loras):
|
||||
mask = lora_indices == ind
|
||||
hs_sub = hidden_states[mask]
|
||||
mat_q = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 0])
|
||||
mat_v = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 1])
|
||||
adapted_query_states.append(mat_q)
|
||||
adapted_value_states.append(mat_v)
|
||||
|
||||
query_adapted = torch.cat(adapted_query_states, dim=0).view(
|
||||
batch_size, self.num_heads, self.head_size
|
||||
)
|
||||
value_adapted = torch.cat(adapted_value_states, dim=0).view(
|
||||
batch_size, self.num_key_value_heads, self.head_size
|
||||
)
|
||||
|
||||
query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
|
||||
|
@ -328,14 +327,15 @@ class LlamaMLP(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, index, prefix, config, weights, all_adapter_weights):
|
||||
def __init__(self, index, prefix, config, weights, lora_weights, lora_configs):
|
||||
super().__init__()
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
index=index,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
all_adapter_weights=all_adapter_weights,
|
||||
lora_weights=lora_weights,
|
||||
lora_configs=lora_configs,
|
||||
)
|
||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
|
@ -391,7 +391,7 @@ class FlashLlamaLayer(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, all_adapter_weights):
|
||||
def __init__(self, prefix, config, weights, lora_weights, lora_configs):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
|
@ -408,7 +408,8 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
all_adapter_weights=all_adapter_weights,
|
||||
lora_weights=lora_weights,
|
||||
lora_configs=lora_configs,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
|
@ -471,7 +472,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, all_adapter_weights):
|
||||
def __init__(self, prefix, config, weights, lora_weights, lora_configs):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
|
@ -480,7 +481,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = FlashLlamaModel(prefix, config, weights, all_adapter_weights)
|
||||
self.model = FlashLlamaModel(
|
||||
prefix, config, weights, lora_weights, lora_configs
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
suffix = "model.embed_tokens"
|
||||
else:
|
||||
|
|
|
@ -1069,7 +1069,8 @@ class FlashCausalLM(Model):
|
|||
for i, r in enumerate(batch.requests):
|
||||
if r.adapter_id:
|
||||
lora_index = self.model.get_lora_index(r.adapter_id)
|
||||
lora_indices[i] = lora_index
|
||||
input_length = batch.input_lengths[i]
|
||||
lora_indices[i : i + input_length] = lora_index
|
||||
batch_lora_adapter_mask[i] = True
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
|
|
|
@ -16,11 +16,11 @@ from text_generation_server.utils import (
|
|||
Weights,
|
||||
hub,
|
||||
)
|
||||
from text_generation_server.utils.weights import load_adaptor_weights
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.lora import LoraConfig
|
||||
|
||||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
|
@ -75,7 +75,9 @@ class FlashLlama(FlashCausalLM):
|
|||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = FlashLlamaForCausalLM(prefix, config, weights, all_adapter_weights)
|
||||
model = FlashLlamaForCausalLM(
|
||||
prefix, config, weights, lora_weights, lora_configs
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashLlama, self).__init__(
|
||||
model=model,
|
||||
|
|
|
@ -86,6 +86,18 @@ def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
|||
return filenames
|
||||
|
||||
|
||||
def _adapter_config_files_from_dir(d: Path) -> List[str]:
|
||||
# os.walk: do not iterate, just scan for depth 1, not recursively
|
||||
# see _weight_files_from_dir, that's also what is done there
|
||||
root, _, files = next(os.walk(str(d)))
|
||||
filenames = [
|
||||
os.path.join(root, f)
|
||||
for f in files
|
||||
if f.endswith(".json") and "arguments" not in f and "args" not in f
|
||||
]
|
||||
return filenames
|
||||
|
||||
|
||||
def _get_cached_revision_directory(
|
||||
model_id: str, revision: Optional[str]
|
||||
) -> Optional[Path]:
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
import json
|
||||
from text_generation_server.utils import (
|
||||
hub,
|
||||
)
|
||||
import os
|
||||
|
||||
|
||||
class LoraConfig:
|
||||
def __init__(
|
||||
self,
|
||||
alpha_pattern=None,
|
||||
auto_mapping=None,
|
||||
base_model_name_or_path="",
|
||||
bias="none",
|
||||
fan_in_fan_out=False,
|
||||
inference_mode=True,
|
||||
init_lora_weights=True,
|
||||
layer_replication=None,
|
||||
layers_pattern=None,
|
||||
layers_to_transform=None,
|
||||
loftq_config=None,
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
megatron_config=None,
|
||||
megatron_core="megatron.core",
|
||||
modules_to_save=None,
|
||||
peft_type="LORA",
|
||||
r=8,
|
||||
rank_pattern=None,
|
||||
revision=None,
|
||||
target_modules=None,
|
||||
task_type="CAUSAL_LM",
|
||||
use_dora=False,
|
||||
use_rslora=False,
|
||||
):
|
||||
self.alpha_pattern = alpha_pattern or {}
|
||||
self.auto_mapping = auto_mapping
|
||||
self.base_model_name_or_path = base_model_name_or_path
|
||||
self.bias = bias
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
self.inference_mode = inference_mode
|
||||
self.init_lora_weights = init_lora_weights
|
||||
self.layer_replication = layer_replication
|
||||
self.layers_pattern = layers_pattern
|
||||
self.layers_to_transform = layers_to_transform
|
||||
self.loftq_config = loftq_config or {}
|
||||
self.lora_alpha = lora_alpha
|
||||
self.lora_dropout = lora_dropout
|
||||
self.megatron_config = megatron_config
|
||||
self.megatron_core = megatron_core
|
||||
self.modules_to_save = modules_to_save
|
||||
self.peft_type = peft_type
|
||||
self.r = r
|
||||
self.rank_pattern = rank_pattern or {}
|
||||
self.revision = revision
|
||||
self.target_modules = target_modules or ["q_proj", "v_proj"]
|
||||
self.task_type = task_type
|
||||
self.use_dora = use_dora
|
||||
self.use_rslora = use_rslora
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, filename):
|
||||
with open(filename, "r") as f:
|
||||
json_data = json.load(f)
|
||||
return cls(**json_data)
|
||||
|
||||
# TODO: support fetching the model from the hub if it's not in the cache
|
||||
@classmethod
|
||||
def from_pretrained(cls, adapter_id, revision=None):
|
||||
d = hub._get_cached_revision_directory(adapter_id, revision)
|
||||
filename = os.path.join(d, "adapter_config.json")
|
||||
return cls.from_file(filename)
|
|
@ -43,3 +43,24 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
|
|||
model.save_pretrained(cache_dir, safe_serialization=True)
|
||||
model.config.save_pretrained(cache_dir)
|
||||
tokenizer.save_pretrained(cache_dir)
|
||||
|
||||
|
||||
def download_peft(model_id, revision, trust_remote_code):
|
||||
torch_dtype = torch.float16
|
||||
try:
|
||||
_model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
except Exception:
|
||||
_model = AutoPeftModelForSeq2SeqLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
logger.info("Peft model downloaded.")
|
||||
|
|
|
@ -10,27 +10,6 @@ import json
|
|||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
||||
# TODO: improve how the weights are loaded
|
||||
def load_adaptor_weights(model_id, local_path, extension=".safetensors"):
|
||||
adapter_weights = {}
|
||||
if local_path.exists() and local_path.is_dir():
|
||||
local_files = list(local_path.glob(f"*{extension}"))
|
||||
if not local_files:
|
||||
raise FileNotFoundError(
|
||||
f"No local weights found in {model_id} with extension {extension}"
|
||||
)
|
||||
for filename in local_files:
|
||||
adapter_weights.update(load_file(filename))
|
||||
|
||||
# TODO: remove (no need to sort)
|
||||
# sorted on the the layer number (index 4 in the key)
|
||||
sorted_keys = sorted(
|
||||
adapter_weights.keys(),
|
||||
key=lambda x: int(x.split(".")[4]),
|
||||
)
|
||||
return (adapter_weights, sorted_keys)
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue