feat: support base model generation and refactors
This commit is contained in:
parent
43ec9dfe32
commit
611225f017
|
@ -8,7 +8,6 @@ from torch.distributed import ProcessGroup
|
|||
|
||||
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||
|
||||
LORA = "lora"
|
||||
from text_generation_server.adapters.weights import (
|
||||
AdapterBatchMetadata,
|
||||
AdapterWeights,
|
||||
|
@ -246,7 +245,7 @@ class BatchLoraWeights(BatchAdapterWeights):
|
|||
|
||||
@classmethod
|
||||
def key(cls) -> str:
|
||||
return LORA
|
||||
return "lora"
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
|
@ -279,9 +278,12 @@ class BatchLoraWeights(BatchAdapterWeights):
|
|||
}
|
||||
|
||||
max_rank = max(
|
||||
(
|
||||
adapter_weights[idx].lora_a_r
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
),
|
||||
default=0,
|
||||
)
|
||||
|
||||
if prefill or max_rank > BGMV_MAX_RANK:
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#############
|
||||
from abc import ABC, abstractclassmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
@ -7,10 +6,6 @@ from typing import Dict, List, Optional, Set, Type
|
|||
import torch
|
||||
|
||||
|
||||
LORA = "lora"
|
||||
LM_HEAD = "lm_head"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterBatchMetadata:
|
||||
# [batch_size]
|
||||
|
@ -127,7 +122,7 @@ class AdapterBatchData:
|
|||
if v.is_empty():
|
||||
continue
|
||||
data[k] = v.get_data(
|
||||
meta, prefill, prefill_head_indices if k == LM_HEAD else None
|
||||
meta, prefill, prefill_head_indices if k == "lm_head" else None
|
||||
)
|
||||
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
|
||||
|
||||
|
@ -135,7 +130,7 @@ class AdapterBatchData:
|
|||
# TODO(travis): refactor to be less coupled to lora implementation
|
||||
ranks = set()
|
||||
for layer_data in self.data.values():
|
||||
lora_data = layer_data.get(LORA)
|
||||
lora_data = layer_data.get("lora")
|
||||
if lora_data is None:
|
||||
continue
|
||||
|
||||
|
|
|
@ -78,7 +78,6 @@ def serve(
|
|||
if otlp_endpoint is not None:
|
||||
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
|
||||
|
||||
# TODO: determine if this api makes sense
|
||||
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
|
||||
|
||||
# split on comma and strip whitespace
|
||||
|
|
|
@ -52,16 +52,6 @@ if SYSTEM == "rocm":
|
|||
except Exception as e:
|
||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||
|
||||
# Constants
|
||||
Q_PROJ = "q_proj"
|
||||
K_PROJ = "k_proj"
|
||||
V_PROJ = "v_proj"
|
||||
O_PROJ = "o_proj"
|
||||
|
||||
GATE_PROJ = "gate_proj"
|
||||
UP_PROJ = "up_proj"
|
||||
DOWN_PROJ = "down_proj"
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights, layer_id):
|
||||
# Only defined in granite.
|
||||
|
@ -100,7 +90,7 @@ def load_attention(config, prefix, weights, layer_id):
|
|||
return TensorParallelMultiAdapterLinear.load(
|
||||
base_layer,
|
||||
layer_id,
|
||||
[Q_PROJ, K_PROJ, V_PROJ],
|
||||
["q_proj", "k_proj", "v_proj"],
|
||||
sizes=[
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
|
@ -160,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
index,
|
||||
O_PROJ,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
|
@ -268,7 +258,7 @@ class LlamaMLP(nn.Module):
|
|||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
index,
|
||||
[GATE_PROJ, UP_PROJ],
|
||||
["gate_proj", "up_proj"],
|
||||
sizes=[
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
|
@ -286,7 +276,7 @@ class LlamaMLP(nn.Module):
|
|||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
index,
|
||||
DOWN_PROJ,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
|
|
|
@ -20,31 +20,17 @@ from text_generation_server.utils import (
|
|||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.lora import LoraConfig
|
||||
|
||||
Q_PROJ = "q_proj"
|
||||
K_PROJ = "k_proj"
|
||||
V_PROJ = "v_proj"
|
||||
O_PROJ = "o_proj"
|
||||
|
||||
GATE_PROJ = "gate_proj"
|
||||
UP_PROJ = "up_proj"
|
||||
DOWN_PROJ = "down_proj"
|
||||
|
||||
LM_HEAD = "lm_head"
|
||||
|
||||
|
||||
# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
|
||||
ADAPTER_LAYERS = [
|
||||
Q_PROJ,
|
||||
K_PROJ,
|
||||
V_PROJ,
|
||||
O_PROJ,
|
||||
GATE_PROJ,
|
||||
UP_PROJ,
|
||||
DOWN_PROJ,
|
||||
] # LM_HEAD
|
||||
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
|
||||
|
||||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
|
@ -123,32 +109,32 @@ class FlashLlama(FlashCausalLM):
|
|||
|
||||
prefix = "model.layers"
|
||||
for i, layer in enumerate(self.model.model.layers):
|
||||
layer_weights[(i, Q_PROJ)] = (
|
||||
layer_weights[(i, "q_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.q_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, K_PROJ)] = (
|
||||
layer_weights[(i, "k_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.k_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, V_PROJ)] = (
|
||||
layer_weights[(i, "v_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.v_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, O_PROJ)] = (
|
||||
layer_weights[(i, "o_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
layer_weights[(i, GATE_PROJ)] = (
|
||||
layer_weights[(i, "gate_proj")] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, UP_PROJ)] = (
|
||||
layer_weights[(i, "up_proj")] = (
|
||||
f"{prefix}.{i}.mlp.up_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, DOWN_PROJ)] = (
|
||||
layer_weights[(i, "down_proj")] = (
|
||||
f"{prefix}.{i}.mlp.down_proj",
|
||||
layer.mlp.down_proj,
|
||||
)
|
||||
|
@ -162,7 +148,7 @@ class FlashLlama(FlashCausalLM):
|
|||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return [Q_PROJ, V_PROJ]
|
||||
return ["q_proj", "v_proj"]
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
|
||||
|
|
|
@ -21,29 +21,16 @@ from text_generation_server.utils.import_utils import SYSTEM
|
|||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
Q_PROJ = "q_proj"
|
||||
K_PROJ = "k_proj"
|
||||
V_PROJ = "v_proj"
|
||||
O_PROJ = "o_proj"
|
||||
|
||||
GATE_PROJ = "gate_proj"
|
||||
UP_PROJ = "up_proj"
|
||||
DOWN_PROJ = "down_proj"
|
||||
|
||||
LM_HEAD = "lm_head"
|
||||
|
||||
|
||||
# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
|
||||
ADAPTER_LAYERS = [
|
||||
Q_PROJ,
|
||||
K_PROJ,
|
||||
V_PROJ,
|
||||
O_PROJ,
|
||||
GATE_PROJ,
|
||||
UP_PROJ,
|
||||
DOWN_PROJ,
|
||||
] # LM_HEAD
|
||||
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
|
||||
|
||||
|
||||
class BaseFlashMistral(FlashCausalLM):
|
||||
|
@ -133,37 +120,37 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
|
||||
prefix = "model.layers"
|
||||
for i, layer in enumerate(self.model.model.layers):
|
||||
layer_weights[(i, Q_PROJ)] = (
|
||||
layer_weights[(i, "q_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.q_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, K_PROJ)] = (
|
||||
layer_weights[(i, "k_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.k_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, V_PROJ)] = (
|
||||
layer_weights[(i, "v_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.v_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, O_PROJ)] = (
|
||||
layer_weights[(i, "o_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
layer_weights[(i, GATE_PROJ)] = (
|
||||
layer_weights[(i, "gate_proj")] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, UP_PROJ)] = (
|
||||
layer_weights[(i, "up_proj")] = (
|
||||
f"{prefix}.{i}.mlp.up_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, DOWN_PROJ)] = (
|
||||
layer_weights[(i, "down_proj")] = (
|
||||
f"{prefix}.{i}.mlp.down_proj",
|
||||
layer.mlp.down_proj,
|
||||
)
|
||||
|
||||
layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
|
||||
layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
|
||||
return layer_weights
|
||||
|
||||
@property
|
||||
|
@ -172,10 +159,10 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return [Q_PROJ, V_PROJ]
|
||||
return ["q_proj", "v_proj"]
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
|
||||
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return layer_type in ROW_PARALLEL
|
||||
|
|
|
@ -251,7 +251,7 @@ def serve(
|
|||
density=1.0,
|
||||
majority_sign_method=0,
|
||||
)
|
||||
adapter_index = index
|
||||
adapter_index = index + 1
|
||||
adapter_to_index[adapter_id] = adapter_index
|
||||
model.load_adapter(
|
||||
adapter_parameters,
|
||||
|
|
|
@ -1,74 +0,0 @@
|
|||
import json
|
||||
from text_generation_server.utils import (
|
||||
hub,
|
||||
)
|
||||
import os
|
||||
|
||||
|
||||
class LoraConfig:
|
||||
def __init__(
|
||||
self,
|
||||
alpha_pattern=None,
|
||||
auto_mapping=None,
|
||||
base_model_name_or_path="",
|
||||
bias="none",
|
||||
fan_in_fan_out=False,
|
||||
inference_mode=True,
|
||||
init_lora_weights=True,
|
||||
layer_replication=None,
|
||||
layers_pattern=None,
|
||||
layers_to_transform=None,
|
||||
loftq_config=None,
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
megatron_config=None,
|
||||
megatron_core="megatron.core",
|
||||
modules_to_save=None,
|
||||
peft_type="LORA",
|
||||
r=8,
|
||||
rank_pattern=None,
|
||||
revision=None,
|
||||
target_modules=None,
|
||||
task_type="CAUSAL_LM",
|
||||
use_dora=False,
|
||||
use_rslora=False,
|
||||
config_path=None,
|
||||
):
|
||||
self.alpha_pattern = alpha_pattern or {}
|
||||
self.auto_mapping = auto_mapping
|
||||
self.base_model_name_or_path = base_model_name_or_path
|
||||
self.bias = bias
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
self.inference_mode = inference_mode
|
||||
self.init_lora_weights = init_lora_weights
|
||||
self.layer_replication = layer_replication
|
||||
self.layers_pattern = layers_pattern
|
||||
self.layers_to_transform = layers_to_transform
|
||||
self.loftq_config = loftq_config or {}
|
||||
self.lora_alpha = lora_alpha
|
||||
self.lora_dropout = lora_dropout
|
||||
self.megatron_config = megatron_config
|
||||
self.megatron_core = megatron_core
|
||||
self.modules_to_save = modules_to_save
|
||||
self.peft_type = peft_type
|
||||
self.r = r
|
||||
self.rank_pattern = rank_pattern or {}
|
||||
self.revision = revision
|
||||
self.target_modules = target_modules or ["q_proj", "v_proj"]
|
||||
self.task_type = task_type
|
||||
self.use_dora = use_dora
|
||||
self.use_rslora = use_rslora
|
||||
self.config_path = config_path
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, filename):
|
||||
with open(filename, "r") as f:
|
||||
json_data = json.load(f)
|
||||
return cls(**json_data, config_path=filename)
|
||||
|
||||
# TODO: support fetching the model from the hub if it's not in the cache
|
||||
@classmethod
|
||||
def from_pretrained(cls, adapter_id, revision=None):
|
||||
d = hub._get_cached_revision_directory(adapter_id, revision)
|
||||
filename = os.path.join(d, "adapter_config.json")
|
||||
return cls.from_file(filename)
|
Loading…
Reference in New Issue