feat: support base model generation and refactors
Parent: 43ec9dfe32
Commit: 611225f017
@@ -8,7 +8,6 @@ from torch.distributed import ProcessGroup
 
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 
-LORA = "lora"
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
     AdapterWeights,
@@ -246,7 +245,7 @@ class BatchLoraWeights(BatchAdapterWeights):
 
     @classmethod
    def key(cls) -> str:
-        return LORA
+        return "lora"
 
     @classmethod
     def load(
@@ -279,9 +278,12 @@ class BatchLoraWeights(BatchAdapterWeights):
         }
 
         max_rank = max(
-            adapter_weights[idx].lora_a_r
-            for idx in segment_indices
-            if idx in adapter_weights
+            (
+                adapter_weights[idx].lora_a_r
+                for idx in segment_indices
+                if idx in adapter_weights
+            ),
+            default=0,
         )
 
         if prefill or max_rank > BGMV_MAX_RANK:
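The `default=0` argument is what lets a batch with no adapter weights at all (plain base-model generation) pass through this code path: `max()` over an empty generator raises `ValueError`, while `max(..., default=0)` simply yields rank 0. A minimal sketch of that behavior in plain Python; the `adapter_weights`/`segment_indices` names mirror the hunk above, and the values are invented:

# Sketch: why max(..., default=0) matters when no adapters are loaded.
adapter_weights = {}          # no adapters -> empty mapping (base model only)
segment_indices = [0, 0, 0]   # every segment points at the "no adapter" slot

# Old form: raises ValueError because the generator is empty.
try:
    max(adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights)
except ValueError as e:
    print("old form:", e)

# New form: falls back to rank 0, so the base-model batch proceeds.
max_rank = max(
    (
        adapter_weights[idx].lora_a_r
        for idx in segment_indices
        if idx in adapter_weights
    ),
    default=0,
)
print("new form: max_rank =", max_rank)  # -> 0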
@@ -1,4 +1,3 @@
-#############
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
 from dataclasses import dataclass
@@ -7,10 +6,6 @@ from typing import Dict, List, Optional, Set, Type
 import torch
 
 
-LORA = "lora"
-LM_HEAD = "lm_head"
-
-
 @dataclass
 class AdapterBatchMetadata:
     # [batch_size]
@@ -127,7 +122,7 @@ class AdapterBatchData:
             if v.is_empty():
                 continue
             data[k] = v.get_data(
-                meta, prefill, prefill_head_indices if k == LM_HEAD else None
+                meta, prefill, prefill_head_indices if k == "lm_head" else None
             )
         return AdapterBatchData(meta=meta, data=data, prefill=prefill)
 
@@ -135,7 +130,7 @@ class AdapterBatchData:
         # TODO(travis): refactor to be less coupled to lora implementation
         ranks = set()
         for layer_data in self.data.values():
-            lora_data = layer_data.get(LORA)
+            lora_data = layer_data.get("lora")
             if lora_data is None:
                 continue
 
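Both hunks rely on the same nesting of `AdapterBatchData.data`: the outer dict is keyed by layer name (e.g. `"q_proj"`, `"lm_head"`) and each value is a dict keyed by adapter type, currently just `"lora"`. With the module-level `LORA`/`LM_HEAD` constants gone, the literal strings are the contract. A rough sketch of that shape; the placeholder weight objects are invented for illustration:

# Rough sketch of the data layout assumed by the hunks above.
# Outer key: layer name; inner key: adapter type ("lora").
data = {
    "q_proj": {"lora": object()},    # placeholder for batched lora weights
    "v_proj": {"lora": object()},
    "lm_head": {},                   # layer present but no lora weights loaded
}

lora_layers = []
for name, layer_data in data.items():
    lora_data = layer_data.get("lora")
    if lora_data is None:
        continue
    lora_layers.append(name)

print(lora_layers)  # -> ['q_proj', 'v_proj']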
@@ -78,7 +78,6 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
 
-    # TODO: determine if this api makes sense
     lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
 
     # split on comma and strip whitespace
@@ -52,16 +52,6 @@ if SYSTEM == "rocm":
     except Exception as e:
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
-# Constants
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
-
 
 def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
@@ -100,7 +90,7 @@ def load_attention(config, prefix, weights, layer_id):
     return TensorParallelMultiAdapterLinear.load(
         base_layer,
         layer_id,
-        [Q_PROJ, K_PROJ, V_PROJ],
+        ["q_proj", "k_proj", "v_proj"],
         sizes=[
             head_size * config.num_attention_heads,
             head_size * config.num_key_value_heads,
@@ -160,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.o_proj = TensorParallelAdapterRowLinear.load(
             o_proj,
             index,
-            O_PROJ,
+            "o_proj",
             process_group=weights.process_group,
         )
 
@@ -268,7 +258,7 @@ class LlamaMLP(nn.Module):
         self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
             gate_up_proj,
             index,
-            [GATE_PROJ, UP_PROJ],
+            ["gate_proj", "up_proj"],
             sizes=[
                 config.intermediate_size,
                 config.intermediate_size,
@@ -286,7 +276,7 @@ class LlamaMLP(nn.Module):
         self.down_proj = TensorParallelAdapterRowLinear.load(
             down_proj,
             index,
-            DOWN_PROJ,
+            "down_proj",
             process_group=weights.process_group,
         )
 
@@ -20,31 +20,17 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)
 
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.lora import LoraConfig
-
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
-
-LM_HEAD = "lm_head"
-
 
-# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
 ADAPTER_LAYERS = [
-    Q_PROJ,
-    K_PROJ,
-    V_PROJ,
-    O_PROJ,
-    GATE_PROJ,
-    UP_PROJ,
-    DOWN_PROJ,
-]  # LM_HEAD
-ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
 
 
 class FlashLlama(FlashCausalLM):
@@ -123,32 +109,32 @@ class FlashLlama(FlashCausalLM):
 
         prefix = "model.layers"
         for i, layer in enumerate(self.model.model.layers):
-            layer_weights[(i, Q_PROJ)] = (
+            layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, K_PROJ)] = (
+            layer_weights[(i, "k_proj")] = (
                 f"{prefix}.{i}.self_attn.k_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, V_PROJ)] = (
+            layer_weights[(i, "v_proj")] = (
                 f"{prefix}.{i}.self_attn.v_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, O_PROJ)] = (
+            layer_weights[(i, "o_proj")] = (
                 f"{prefix}.{i}.self_attn.o_proj",
                 layer.self_attn.o_proj,
             )
 
-            layer_weights[(i, GATE_PROJ)] = (
+            layer_weights[(i, "gate_proj")] = (
                 f"{prefix}.{i}.mlp.gate_proj",
                 layer.mlp.gate_up_proj,
             )
-            layer_weights[(i, UP_PROJ)] = (
+            layer_weights[(i, "up_proj")] = (
                 f"{prefix}.{i}.mlp.up_proj",
                 layer.mlp.gate_up_proj,
             )
-            layer_weights[(i, DOWN_PROJ)] = (
+            layer_weights[(i, "down_proj")] = (
                 f"{prefix}.{i}.mlp.down_proj",
                 layer.mlp.down_proj,
             )
@@ -162,7 +148,7 @@ class FlashLlama(FlashCausalLM):
 
     @property
     def default_traced_adapter_layers(self) -> List[str]:
-        return [Q_PROJ, V_PROJ]
+        return ["q_proj", "v_proj"]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
         return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
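The refactor drops the per-file `*_PROJ` constants in favor of the literal layer names, which are the same strings that appear in a PEFT adapter's `target_modules` and in the `layer_weights` keys above. A small sketch of how the two module-level tables are meant to be consulted; the helpers mirror `is_row_parallel`/`get_num_layers_for_type` from the diff, and the layer count is invented:

# Sketch: literal layer names as the lookup keys (num_layers is a made-up value).
ADAPTER_LAYERS = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}

def is_row_parallel(layer_type: str) -> bool:
    return layer_type in ROW_PARALLEL

def get_num_layers_for_type(layer_type: str, num_layers: int = 32) -> int:
    # lm_head exists once; every other adapter layer exists once per decoder block
    return 1 if layer_type == "lm_head" else num_layers

assert is_row_parallel("down_proj") and not is_row_parallel("q_proj")
assert get_num_layers_for_type("lm_head") == 1
assert get_num_layers_for_type("q_proj") == 32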
@@ -21,29 +21,16 @@ from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
-
-LM_HEAD = "lm_head"
-
-
-# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
 ADAPTER_LAYERS = [
-    Q_PROJ,
-    K_PROJ,
-    V_PROJ,
-    O_PROJ,
-    GATE_PROJ,
-    UP_PROJ,
-    DOWN_PROJ,
-]  # LM_HEAD
-ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
 
 
 class BaseFlashMistral(FlashCausalLM):
@@ -133,37 +120,37 @@ class BaseFlashMistral(FlashCausalLM):
 
         prefix = "model.layers"
         for i, layer in enumerate(self.model.model.layers):
-            layer_weights[(i, Q_PROJ)] = (
+            layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, K_PROJ)] = (
+            layer_weights[(i, "k_proj")] = (
                 f"{prefix}.{i}.self_attn.k_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, V_PROJ)] = (
+            layer_weights[(i, "v_proj")] = (
                 f"{prefix}.{i}.self_attn.v_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, O_PROJ)] = (
+            layer_weights[(i, "o_proj")] = (
                 f"{prefix}.{i}.self_attn.o_proj",
                 layer.self_attn.o_proj,
             )
 
-            layer_weights[(i, GATE_PROJ)] = (
+            layer_weights[(i, "gate_proj")] = (
                 f"{prefix}.{i}.mlp.gate_proj",
                 layer.mlp.gate_up_proj,
             )
-            layer_weights[(i, UP_PROJ)] = (
+            layer_weights[(i, "up_proj")] = (
                 f"{prefix}.{i}.mlp.up_proj",
                 layer.mlp.gate_up_proj,
             )
-            layer_weights[(i, DOWN_PROJ)] = (
+            layer_weights[(i, "down_proj")] = (
                 f"{prefix}.{i}.mlp.down_proj",
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
         return layer_weights
 
     @property
@@ -172,10 +159,10 @@ class BaseFlashMistral(FlashCausalLM):
 
     @property
     def default_traced_adapter_layers(self) -> List[str]:
-        return [Q_PROJ, V_PROJ]
+        return ["q_proj", "v_proj"]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
@@ -251,7 +251,7 @@ def serve(
                     density=1.0,
                     majority_sign_method=0,
                 )
-                adapter_index = index
+                adapter_index = index + 1
                 adapter_to_index[adapter_id] = adapter_index
                 model.load_adapter(
                     adapter_parameters,
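Shifting `adapter_index` to `index + 1` presumably reserves index 0 for the un-adapted base model, so a request that names no adapter can resolve to slot 0. A sketch of the resulting mapping; the adapter ids are made up and the "slot 0 is the base model" reading is an assumption drawn from the commit title, not shown explicitly in this hunk:

# Sketch of the indexing convention implied by the change.
lora_adapter_ids = ["org/adapter-a", "org/adapter-b"]

adapter_to_index = {}
for index, adapter_id in enumerate(lora_adapter_ids):
    adapter_index = index + 1   # adapters now start at 1; 0 stays free for "no adapter"
    adapter_to_index[adapter_id] = adapter_index

print(adapter_to_index)  # -> {'org/adapter-a': 1, 'org/adapter-b': 2}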
@@ -1,74 +0,0 @@
-import json
-from text_generation_server.utils import (
-    hub,
-)
-import os
-
-
-class LoraConfig:
-    def __init__(
-        self,
-        alpha_pattern=None,
-        auto_mapping=None,
-        base_model_name_or_path="",
-        bias="none",
-        fan_in_fan_out=False,
-        inference_mode=True,
-        init_lora_weights=True,
-        layer_replication=None,
-        layers_pattern=None,
-        layers_to_transform=None,
-        loftq_config=None,
-        lora_alpha=16,
-        lora_dropout=0.1,
-        megatron_config=None,
-        megatron_core="megatron.core",
-        modules_to_save=None,
-        peft_type="LORA",
-        r=8,
-        rank_pattern=None,
-        revision=None,
-        target_modules=None,
-        task_type="CAUSAL_LM",
-        use_dora=False,
-        use_rslora=False,
-        config_path=None,
-    ):
-        self.alpha_pattern = alpha_pattern or {}
-        self.auto_mapping = auto_mapping
-        self.base_model_name_or_path = base_model_name_or_path
-        self.bias = bias
-        self.fan_in_fan_out = fan_in_fan_out
-        self.inference_mode = inference_mode
-        self.init_lora_weights = init_lora_weights
-        self.layer_replication = layer_replication
-        self.layers_pattern = layers_pattern
-        self.layers_to_transform = layers_to_transform
-        self.loftq_config = loftq_config or {}
-        self.lora_alpha = lora_alpha
-        self.lora_dropout = lora_dropout
-        self.megatron_config = megatron_config
-        self.megatron_core = megatron_core
-        self.modules_to_save = modules_to_save
-        self.peft_type = peft_type
-        self.r = r
-        self.rank_pattern = rank_pattern or {}
-        self.revision = revision
-        self.target_modules = target_modules or ["q_proj", "v_proj"]
-        self.task_type = task_type
-        self.use_dora = use_dora
-        self.use_rslora = use_rslora
-        self.config_path = config_path
-
-    @classmethod
-    def from_file(cls, filename):
-        with open(filename, "r") as f:
-            json_data = json.load(f)
-        return cls(**json_data, config_path=filename)
-
-    # TODO: support fetching the model from the hub if it's not in the cache
-    @classmethod
-    def from_pretrained(cls, adapter_id, revision=None):
-        d = hub._get_cached_revision_directory(adapter_id, revision)
-        filename = os.path.join(d, "adapter_config.json")
-        return cls.from_file(filename)
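The hand-rolled `LoraConfig` mirror of PEFT's adapter config is deleted outright. The diff does not show what replaces it, but `adapter_config.json` is PEFT's own format, so the same fields can be read with PEFT directly. A hedged sketch of that alternative; the adapter id is a placeholder, and whether the server actually switches to this loader is an assumption not visible in this commit:

# Hedged sketch: loading adapter_config.json via peft instead of the deleted wrapper.
# "some-org/some-lora-adapter" is a placeholder id (local path or Hub repo).
from peft import LoraConfig

config = LoraConfig.from_pretrained("some-org/some-lora-adapter")
print(config.r, config.lora_alpha, config.target_modules)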