feat: support base model generation and refactors

This commit is contained in:
drbh 2024-06-07 01:20:41 +00:00
parent 43ec9dfe32
commit 611225f017
8 changed files with 50 additions and 165 deletions

View File

@ -8,7 +8,6 @@ from torch.distributed import ProcessGroup
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
LORA = "lora"
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
@ -246,7 +245,7 @@ class BatchLoraWeights(BatchAdapterWeights):
@classmethod
def key(cls) -> str:
return LORA
return "lora"
@classmethod
def load(
@ -279,9 +278,12 @@ class BatchLoraWeights(BatchAdapterWeights):
}
max_rank = max(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
),
default=0,
)
if prefill or max_rank > BGMV_MAX_RANK:

View File

@ -1,4 +1,3 @@
#############
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
@ -7,10 +6,6 @@ from typing import Dict, List, Optional, Set, Type
import torch
LORA = "lora"
LM_HEAD = "lm_head"
@dataclass
class AdapterBatchMetadata:
# [batch_size]
@ -127,7 +122,7 @@ class AdapterBatchData:
if v.is_empty():
continue
data[k] = v.get_data(
meta, prefill, prefill_head_indices if k == LM_HEAD else None
meta, prefill, prefill_head_indices if k == "lm_head" else None
)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
@ -135,7 +130,7 @@ class AdapterBatchData:
# TODO(travis): refactor to be less coupled to lora implementation
ranks = set()
for layer_data in self.data.values():
lora_data = layer_data.get(LORA)
lora_data = layer_data.get("lora")
if lora_data is None:
continue

View File

@ -78,7 +78,6 @@ def serve(
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
# TODO: determine if this api makes sense
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
# split on comma and strip whitespace

View File

@ -52,16 +52,6 @@ if SYSTEM == "rocm":
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
# Constants
Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"
GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"
def load_attention(config, prefix, weights, layer_id):
# Only defined in granite.
@ -100,7 +90,7 @@ def load_attention(config, prefix, weights, layer_id):
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
[Q_PROJ, K_PROJ, V_PROJ],
["q_proj", "k_proj", "v_proj"],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
@ -160,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
O_PROJ,
"o_proj",
process_group=weights.process_group,
)
@ -268,7 +258,7 @@ class LlamaMLP(nn.Module):
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
[GATE_PROJ, UP_PROJ],
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
@ -286,7 +276,7 @@ class LlamaMLP(nn.Module):
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
DOWN_PROJ,
"down_proj",
process_group=weights.process_group,
)

View File

@ -20,31 +20,17 @@ from text_generation_server.utils import (
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.lora import LoraConfig
Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"
GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"
LM_HEAD = "lm_head"
# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
ADAPTER_LAYERS = [
Q_PROJ,
K_PROJ,
V_PROJ,
O_PROJ,
GATE_PROJ,
UP_PROJ,
DOWN_PROJ,
] # LM_HEAD
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
@ -123,32 +109,32 @@ class FlashLlama(FlashCausalLM):
prefix = "model.layers"
for i, layer in enumerate(self.model.model.layers):
layer_weights[(i, Q_PROJ)] = (
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, K_PROJ)] = (
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, V_PROJ)] = (
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, O_PROJ)] = (
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, GATE_PROJ)] = (
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, UP_PROJ)] = (
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, DOWN_PROJ)] = (
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
@ -162,7 +148,7 @@ class FlashLlama(FlashCausalLM):
@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)

View File

@ -21,29 +21,16 @@ from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"
GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"
LM_HEAD = "lm_head"
# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
ADAPTER_LAYERS = [
Q_PROJ,
K_PROJ,
V_PROJ,
O_PROJ,
GATE_PROJ,
UP_PROJ,
DOWN_PROJ,
] # LM_HEAD
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class BaseFlashMistral(FlashCausalLM):
@ -133,37 +120,37 @@ class BaseFlashMistral(FlashCausalLM):
prefix = "model.layers"
for i, layer in enumerate(self.model.model.layers):
layer_weights[(i, Q_PROJ)] = (
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, K_PROJ)] = (
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, V_PROJ)] = (
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, O_PROJ)] = (
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, GATE_PROJ)] = (
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, UP_PROJ)] = (
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, DOWN_PROJ)] = (
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
return layer_weights
@property
@ -172,10 +159,10 @@ class BaseFlashMistral(FlashCausalLM):
@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

View File

@ -251,7 +251,7 @@ def serve(
density=1.0,
majority_sign_method=0,
)
adapter_index = index
adapter_index = index + 1
adapter_to_index[adapter_id] = adapter_index
model.load_adapter(
adapter_parameters,

View File

@ -1,74 +0,0 @@
import json
from text_generation_server.utils import (
hub,
)
import os
class LoraConfig:
def __init__(
self,
alpha_pattern=None,
auto_mapping=None,
base_model_name_or_path="",
bias="none",
fan_in_fan_out=False,
inference_mode=True,
init_lora_weights=True,
layer_replication=None,
layers_pattern=None,
layers_to_transform=None,
loftq_config=None,
lora_alpha=16,
lora_dropout=0.1,
megatron_config=None,
megatron_core="megatron.core",
modules_to_save=None,
peft_type="LORA",
r=8,
rank_pattern=None,
revision=None,
target_modules=None,
task_type="CAUSAL_LM",
use_dora=False,
use_rslora=False,
config_path=None,
):
self.alpha_pattern = alpha_pattern or {}
self.auto_mapping = auto_mapping
self.base_model_name_or_path = base_model_name_or_path
self.bias = bias
self.fan_in_fan_out = fan_in_fan_out
self.inference_mode = inference_mode
self.init_lora_weights = init_lora_weights
self.layer_replication = layer_replication
self.layers_pattern = layers_pattern
self.layers_to_transform = layers_to_transform
self.loftq_config = loftq_config or {}
self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout
self.megatron_config = megatron_config
self.megatron_core = megatron_core
self.modules_to_save = modules_to_save
self.peft_type = peft_type
self.r = r
self.rank_pattern = rank_pattern or {}
self.revision = revision
self.target_modules = target_modules or ["q_proj", "v_proj"]
self.task_type = task_type
self.use_dora = use_dora
self.use_rslora = use_rslora
self.config_path = config_path
@classmethod
def from_file(cls, filename):
with open(filename, "r") as f:
json_data = json.load(f)
return cls(**json_data, config_path=filename)
# TODO: support fetching the model from the hub if it's not in the cache
@classmethod
def from_pretrained(cls, adapter_id, revision=None):
d = hub._get_cached_revision_directory(adapter_id, revision)
filename = os.path.join(d, "adapter_config.json")
return cls.from_file(filename)