hf_text-generation-inference/server/text_generation_server/utils/merges/utils.py

109 lines
3.9 KiB
Python
Raw Normal View History

Enable multiple LoRa adapters (#2010) * feat: first draft load multiple lora * feat: load weights within layer and refactor lora pass * fix: refactor and reduce lora math * feat: baseline impl single request multi lora support * feat: prefer lorax implementation and port loading logic * fix: prefer adapter_data and refactors * feat: perfer loraxs custom punica kernels and add mlp loras * fix: adjust batch for bgmv * fix: adjust adapter_segments logic when in batch * fix: refactor and move changes to v3 proto * fix: pass model_id for all flash causal lms * fix: pass model_id for all causal and seq2seq lms * fix: add model_id to model test * feat: add lora support to mistral and refactors * feat: prefer model id in request * fix: include rust code for adapter id * feat: bump launcher and add new lora docs * feat: support base model generation and refactors * fix: rename doc to retry ci build * feat: support if vlm models * fix: add adapter_data param and avoid missing layers * fix: add adapter_data param to phi and neox * fix: update all models forwards to include adapter_data * fix: add model_id to IdeficsCausalLM * Update lora.md Fixed a typo * Update lora.md Fixing spam image * fix: add lora kernel to dockerfile, support running without kernels and refactors * fix: avoid dockerfile conflict * fix: refactors and adjust flash llama lora logic * fix: skip llama test due to CI issue (temp) * fix: skip llama test CI (temp) 2 * fix: revert skips and prefer updated ci token for tests * fix: refactors and helpful comments * fix: add noop in TensorParallelAdapterRowLinear too * fix: refactor and move shard_lora_weights logic * fix: exit early if no adapter_data --------- Co-authored-by: Derek <datavistics@gmail.com>
2024-06-25 12:46:27 -06:00
# coding=utf-8
# From: https://github.com/huggingface/peft/pull/1364
# Copyright 2024-present the HuggingFace Inc. team.
# Modifications by Predibase, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
import torch
def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """
    Keep the largest-magnitude entries of `tensor` and zero out everything else.

    The number of entries kept is `int(density * tensor.numel())` (truncated), chosen by
    `torch.topk` over the flattened absolute values.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].

    Returns:
        `torch.Tensor`: A tensor with the same shape as `tensor` where only the kept entries
        retain their original values.
    """
    # Flat 0/1 buffer; reshaped back to the input's shape at the end.
    keep = torch.zeros_like(tensor).reshape(-1)
    flat_abs = tensor.abs().reshape(-1)
    keep_count = int(density * flat_abs.shape[0])
    _, keep_idx = torch.topk(flat_abs, k=keep_count, largest=True)
    keep[keep_idx] = 1
    return tensor * keep.reshape(tensor.shape)
def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
    """
    Randomly prune `tensor`, keeping each entry independently with probability `density`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The probability of keeping each value. Should be in [0,1].
        rescale (`bool`): Whether to divide the surviving values by `density` so the pruned
            tensor preserves the expected value of the original tensor.

    Returns:
        `torch.Tensor`: A tensor with the same shape as `tensor` in which dropped entries are
        zero and kept entries are unchanged (or divided by `density` when `rescale` is set).
    """
    # One Bernoulli(density) draw per element: 1 = keep, 0 = drop.
    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
    pruned_tensor = tensor * mask
    # torch.div is out-of-place, so its result must be assigned.
    # With density == 0 every entry was dropped; dividing would turn 0/0 into NaN.
    if rescale and density > 0:
        pruned_tensor = torch.div(input=pruned_tensor, other=density)
    return pruned_tensor
def prune(
    tensor: torch.Tensor,
    density: float,
    method: Literal["magnitude", "random"],
    rescale: bool = False,
) -> torch.Tensor:
    """
    Prune the values of task tensors based on the `method`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        method (`str`): The method to use to prune. Should be one of ["magnitude", "random"].
        rescale (`bool`): Whether to rescale the result to preserve the expected value of the
            original tensor. Only forwarded to the "random" method.

    Returns:
        `torch.Tensor`: `tensor` itself (not a copy) when `density >= 1`, otherwise the
        pruned tensor produced by the selected method.

    Raises:
        ValueError: If `density < 0` or `method` is not a supported value.
    """
    if density >= 1:
        # Nothing to prune; note `method` is not validated on this path.
        return tensor
    elif density < 0:
        raise ValueError(f"Density should be >= 0, got {density}")
    if method == "magnitude":
        return magnitude_based_pruning(tensor, density)
    elif method == "random":
        return random_pruning(tensor, density, rescale=rescale)
    else:
        raise ValueError(f"Unknown method {method}")
def calculate_majority_sign_mask(
    tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
):
    """
    Mark, element-wise, which task tensors agree with the majority sign along dimension 0.

    Task tensors are stacked on dimension 0. For each position, a signed score is summed over
    the tasks; a score >= 0 makes +1 the majority sign, otherwise -1.

    Args:
        tensor (`torch.Tensor`): The stacked task tensors to get the mask from.
        method (`str`): How the score is computed. "total" sums the signed magnitudes;
            "frequency" sums the signs only (a vote count).

    Returns:
        A boolean tensor with the same shape as `tensor`, True where an entry's sign equals
        the majority sign of its position.

    Raises:
        RuntimeError: If `method` is neither "total" nor "frequency".
    """
    signs = tensor.sign()
    if method == "frequency":
        score = signs.sum(dim=0)
    elif method == "total":
        score = (signs * tensor.abs()).sum(dim=0)
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')
    # Ties (score == 0) resolve to +1.
    winner = torch.where(score >= 0, 1, -1)
    return signs == winner
def disjoint_merge(task_tensors, majority_sign_mask):
    """
    Average the stacked task tensors over dimension 0, counting only the masked-in entries.

    Args:
        task_tensors (`torch.Tensor`): Task tensors stacked on dimension 0.
        majority_sign_mask (`torch.Tensor`): Mask of the same shape; entries where it is
            False/0 are excluded from both the sum and the count.

    Returns:
        `torch.Tensor`: The sum of kept entries divided by how many were kept at each position.
        The divisor is clamped to at least 1, so positions with no kept entries yield 0
        instead of NaN.
    """
    kept_sum = (task_tensors * majority_sign_mask).sum(dim=0)
    kept_count = torch.clamp(majority_sign_mask.sum(dim=0), min=1.0)
    return kept_sum / kept_count