feat: support base model generation and refactors

drbh 2024-06-07 01:20:41 +00:00
parent 43ec9dfe32
commit 611225f017
8 changed files with 50 additions and 165 deletions

View File

@@ -8,7 +8,6 @@ from torch.distributed import ProcessGroup
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
-LORA = "lora"
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
     AdapterWeights,
@@ -246,7 +245,7 @@ class BatchLoraWeights(BatchAdapterWeights):
     @classmethod
     def key(cls) -> str:
-        return LORA
+        return "lora"
     @classmethod
     def load(
@@ -279,9 +278,12 @@ class BatchLoraWeights(BatchAdapterWeights):
         }
         max_rank = max(
-            adapter_weights[idx].lora_a_r
-            for idx in segment_indices
-            if idx in adapter_weights
+            (
+                adapter_weights[idx].lora_a_r
+                for idx in segment_indices
+                if idx in adapter_weights
+            ),
+            default=0,
         )
         if prefill or max_rank > BGMV_MAX_RANK:
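Note on the max_rank change above: wrapping the generator and adding default=0 appears to be what lets a batch that contains no LoRA adapters (plain base-model requests) reach this code path without raising. A minimal illustration, not taken from the repository; the variable contents are assumptions:

# Minimal sketch: max() over an empty generator raises ValueError, so a batch
# with no adapter weights would previously crash here; default=0 falls back to rank 0.
adapter_weights = {}          # hypothetical: no adapter weights loaded
segment_indices = [0, 0, 0]   # hypothetical segment -> adapter-index mapping

try:
    max(
        adapter_weights[idx].lora_a_r
        for idx in segment_indices
        if idx in adapter_weights
    )
except ValueError as err:
    print(err)  # empty sequence -> ValueError

max_rank = max(
    (
        adapter_weights[idx].lora_a_r
        for idx in segment_indices
        if idx in adapter_weights
    ),
    default=0,
)
print(max_rank)  # 0, i.e. no adapter applies to this batch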

View File

@@ -1,4 +1,3 @@
-#############
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
 from dataclasses import dataclass
@@ -7,10 +6,6 @@ from typing import Dict, List, Optional, Set, Type
 import torch
-LORA = "lora"
-LM_HEAD = "lm_head"
 @dataclass
 class AdapterBatchMetadata:
     # [batch_size]
@@ -127,7 +122,7 @@ class AdapterBatchData:
             if v.is_empty():
                 continue
             data[k] = v.get_data(
-                meta, prefill, prefill_head_indices if k == LM_HEAD else None
+                meta, prefill, prefill_head_indices if k == "lm_head" else None
             )
         return AdapterBatchData(meta=meta, data=data, prefill=prefill)
@@ -135,7 +130,7 @@ class AdapterBatchData:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for layer_data in self.data.values():
-            lora_data = layer_data.get(LORA)
+            lora_data = layer_data.get("lora")
            if lora_data is None:
                continue
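A side note on the prefill_head_indices handling in the hunk above: only the "lm_head" layer type receives those indices, since during prefill logits are typically needed only for the last position of each sequence, while every other projection still sees all tokens. A toy illustration with assumed names and shapes, not the server's code:

import torch

# Two sequences of length 3 and 4 packed into one prefill batch of 7 tokens.
cu_seqlen_prefill = torch.tensor([0, 3, 7])

# Only the last token of each sequence needs logits, so lm_head-level adapter
# data gets just those positions.
prefill_head_indices = cu_seqlen_prefill[1:] - 1
print(prefill_head_indices)  # tensor([2, 6])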

View File

@@ -78,7 +78,6 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
-    # TODO: determine if this api makes sense
     lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
     # split on comma and strip whitespace
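The parsing hinted at by that comment sits outside this hunk; a rough sketch of what it amounts to (the variable handling here is an assumption, not the repository's code):

import os

raw = os.getenv("LORA_ADAPTERS", None)
# Split on comma and strip whitespace, dropping empty entries.
lora_adapter_ids = (
    [adapter_id.strip() for adapter_id in raw.split(",") if adapter_id.strip()]
    if raw
    else []
)
# e.g. LORA_ADAPTERS="org/adapter-one, org/adapter-two"
#   -> ["org/adapter-one", "org/adapter-two"]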

View File

@@ -52,16 +52,6 @@ if SYSTEM == "rocm":
     except Exception as e:
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
-# Constants
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
 def load_attention(config, prefix, weights, layer_id):
     # Only defined in granite.
@@ -100,7 +90,7 @@ def load_attention(config, prefix, weights, layer_id):
     return TensorParallelMultiAdapterLinear.load(
         base_layer,
         layer_id,
-        [Q_PROJ, K_PROJ, V_PROJ],
+        ["q_proj", "k_proj", "v_proj"],
         sizes=[
             head_size * config.num_attention_heads,
             head_size * config.num_key_value_heads,
@@ -160,7 +150,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.o_proj = TensorParallelAdapterRowLinear.load(
             o_proj,
             index,
-            O_PROJ,
+            "o_proj",
             process_group=weights.process_group,
         )
@@ -268,7 +258,7 @@ class LlamaMLP(nn.Module):
         self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
             gate_up_proj,
             index,
-            [GATE_PROJ, UP_PROJ],
+            ["gate_proj", "up_proj"],
             sizes=[
                 config.intermediate_size,
                 config.intermediate_size,
@@ -286,7 +276,7 @@ class LlamaMLP(nn.Module):
         self.down_proj = TensorParallelAdapterRowLinear.load(
             down_proj,
             index,
-            DOWN_PROJ,
+            "down_proj",
             process_group=weights.process_group,
         )

View File

@@ -20,31 +20,17 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.lora import LoraConfig
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
-LM_HEAD = "lm_head"
+# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
 ADAPTER_LAYERS = [
-    Q_PROJ,
-    K_PROJ,
-    V_PROJ,
-    O_PROJ,
-    GATE_PROJ,
-    UP_PROJ,
-    DOWN_PROJ,
-]  # LM_HEAD
-ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
 class FlashLlama(FlashCausalLM):
@@ -123,32 +109,32 @@ class FlashLlama(FlashCausalLM):
         prefix = "model.layers"
         for i, layer in enumerate(self.model.model.layers):
-            layer_weights[(i, Q_PROJ)] = (
+            layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, K_PROJ)] = (
+            layer_weights[(i, "k_proj")] = (
                 f"{prefix}.{i}.self_attn.k_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, V_PROJ)] = (
+            layer_weights[(i, "v_proj")] = (
                 f"{prefix}.{i}.self_attn.v_proj",
                 layer.self_attn.query_key_value,
            )
-            layer_weights[(i, O_PROJ)] = (
+            layer_weights[(i, "o_proj")] = (
                 f"{prefix}.{i}.self_attn.o_proj",
                 layer.self_attn.o_proj,
             )
-            layer_weights[(i, GATE_PROJ)] = (
+            layer_weights[(i, "gate_proj")] = (
                 f"{prefix}.{i}.mlp.gate_proj",
                 layer.mlp.gate_up_proj,
             )
-            layer_weights[(i, UP_PROJ)] = (
+            layer_weights[(i, "up_proj")] = (
                 f"{prefix}.{i}.mlp.up_proj",
                 layer.mlp.gate_up_proj,
             )
-            layer_weights[(i, DOWN_PROJ)] = (
+            layer_weights[(i, "down_proj")] = (
                 f"{prefix}.{i}.mlp.down_proj",
                 layer.mlp.down_proj,
             )
@@ -162,7 +148,7 @@ class FlashLlama(FlashCausalLM):
     @property
     def default_traced_adapter_layers(self) -> List[str]:
-        return [Q_PROJ, V_PROJ]
+        return ["q_proj", "v_proj"]
     def get_num_layers_for_type(self, layer_type: str) -> int:
         return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
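One detail worth calling out in the layer_weights mapping above: q_proj, k_proj and v_proj all point at the same fused query_key_value module, and the sizes passed to TensorParallelMultiAdapterLinear.load in the modeling file slice that fused output back into per-projection ranges. A toy illustration of the sharing, with stand-in objects rather than the server's modules:

# Stand-in for layer.self_attn.query_key_value (hypothetical).
fused_qkv = object()

layer_weights = {
    (0, "q_proj"): ("model.layers.0.self_attn.q_proj", fused_qkv),
    (0, "k_proj"): ("model.layers.0.self_attn.k_proj", fused_qkv),
    (0, "v_proj"): ("model.layers.0.self_attn.v_proj", fused_qkv),
}

# All three adapter layer types target one module; they are distinguished by
# the output slice (q / k / v sizes) rather than by separate linear layers.
assert len({id(module) for _, module in layer_weights.values()}) == 1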

View File

@@ -21,29 +21,16 @@ from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
-LM_HEAD = "lm_head"
+# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
 ADAPTER_LAYERS = [
-    Q_PROJ,
-    K_PROJ,
-    V_PROJ,
-    O_PROJ,
-    GATE_PROJ,
-    UP_PROJ,
-    DOWN_PROJ,
-]  # LM_HEAD
-ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
 class BaseFlashMistral(FlashCausalLM):
@@ -133,37 +120,37 @@ class BaseFlashMistral(FlashCausalLM):
         prefix = "model.layers"
         for i, layer in enumerate(self.model.model.layers):
-            layer_weights[(i, Q_PROJ)] = (
+            layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, K_PROJ)] = (
+            layer_weights[(i, "k_proj")] = (
                 f"{prefix}.{i}.self_attn.k_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, V_PROJ)] = (
+            layer_weights[(i, "v_proj")] = (
                 f"{prefix}.{i}.self_attn.v_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, O_PROJ)] = (
+            layer_weights[(i, "o_proj")] = (
                 f"{prefix}.{i}.self_attn.o_proj",
                 layer.self_attn.o_proj,
             )
-            layer_weights[(i, GATE_PROJ)] = (
+            layer_weights[(i, "gate_proj")] = (
                 f"{prefix}.{i}.mlp.gate_proj",
                 layer.mlp.gate_up_proj,
             )
-            layer_weights[(i, UP_PROJ)] = (
+            layer_weights[(i, "up_proj")] = (
                 f"{prefix}.{i}.mlp.up_proj",
                 layer.mlp.gate_up_proj,
             )
-            layer_weights[(i, DOWN_PROJ)] = (
+            layer_weights[(i, "down_proj")] = (
                 f"{prefix}.{i}.mlp.down_proj",
                 layer.mlp.down_proj,
             )
-        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
         return layer_weights
     @property
@@ -172,10 +159,10 @@ class BaseFlashMistral(FlashCausalLM):
     @property
     def default_traced_adapter_layers(self) -> List[str]:
-        return [Q_PROJ, V_PROJ]
+        return ["q_proj", "v_proj"]
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
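For context on ROW_PARALLEL above: o_proj, down_proj and lm_head are the layers whose input dimension is sharded across tensor-parallel ranks, so their adapter weights need to be partitioned consistently with the base layer, and is_row_parallel simply reports that. A hedged sketch of the kind of decision that flag drives; the helper below is hypothetical, not the server's code:

import torch

ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}

def shard_lora_a(lora_a: torch.Tensor, layer_type: str, rank: int, world_size: int) -> torch.Tensor:
    """Hypothetical helper: a row-parallel linear splits its input dimension
    across ranks, so the LoRA A matrix (input x r) is sliced the same way;
    column-parallel layers keep A whole on every rank."""
    if layer_type in ROW_PARALLEL:
        chunk = lora_a.shape[0] // world_size
        return lora_a[rank * chunk : (rank + 1) * chunk]
    return lora_a

# e.g. a rank-8 adapter on a 4096-wide o_proj, sharded over 2 GPUs:
print(shard_lora_a(torch.zeros(4096, 8), "o_proj", rank=0, world_size=2).shape)  # torch.Size([2048, 8])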

View File

@@ -251,7 +251,7 @@ def serve(
            density=1.0,
            majority_sign_method=0,
        )
-        adapter_index = index
+        adapter_index = index + 1
        adapter_to_index[adapter_id] = adapter_index
        model.load_adapter(
            adapter_parameters,
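A hedged reading of the index + 1 change (an inference from the commit title, not stated in the diff): shifting user-supplied adapters to 1-based indices appears to leave index 0 free to mean "no adapter", i.e. plain base-model generation. Schematically, with hypothetical adapter IDs:

BASE_MODEL_INDEX = 0  # assumed sentinel for "no adapter"

adapter_to_index = {}
for index, adapter_id in enumerate(["org/adapter-a", "org/adapter-b"]):
    adapter_to_index[adapter_id] = index + 1

print(adapter_to_index)  # {'org/adapter-a': 1, 'org/adapter-b': 2}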

View File

@@ -1,74 +0,0 @@
-import json
-from text_generation_server.utils import (
-    hub,
-)
-import os
-
-
-class LoraConfig:
-    def __init__(
-        self,
-        alpha_pattern=None,
-        auto_mapping=None,
-        base_model_name_or_path="",
-        bias="none",
-        fan_in_fan_out=False,
-        inference_mode=True,
-        init_lora_weights=True,
-        layer_replication=None,
-        layers_pattern=None,
-        layers_to_transform=None,
-        loftq_config=None,
-        lora_alpha=16,
-        lora_dropout=0.1,
-        megatron_config=None,
-        megatron_core="megatron.core",
-        modules_to_save=None,
-        peft_type="LORA",
-        r=8,
-        rank_pattern=None,
-        revision=None,
-        target_modules=None,
-        task_type="CAUSAL_LM",
-        use_dora=False,
-        use_rslora=False,
-        config_path=None,
-    ):
-        self.alpha_pattern = alpha_pattern or {}
-        self.auto_mapping = auto_mapping
-        self.base_model_name_or_path = base_model_name_or_path
-        self.bias = bias
-        self.fan_in_fan_out = fan_in_fan_out
-        self.inference_mode = inference_mode
-        self.init_lora_weights = init_lora_weights
-        self.layer_replication = layer_replication
-        self.layers_pattern = layers_pattern
-        self.layers_to_transform = layers_to_transform
-        self.loftq_config = loftq_config or {}
-        self.lora_alpha = lora_alpha
-        self.lora_dropout = lora_dropout
-        self.megatron_config = megatron_config
-        self.megatron_core = megatron_core
-        self.modules_to_save = modules_to_save
-        self.peft_type = peft_type
-        self.r = r
-        self.rank_pattern = rank_pattern or {}
-        self.revision = revision
-        self.target_modules = target_modules or ["q_proj", "v_proj"]
-        self.task_type = task_type
-        self.use_dora = use_dora
-        self.use_rslora = use_rslora
-        self.config_path = config_path
-
-    @classmethod
-    def from_file(cls, filename):
-        with open(filename, "r") as f:
-            json_data = json.load(f)
-        return cls(**json_data, config_path=filename)
-
-    # TODO: support fetching the model from the hub if it's not in the cache
-    @classmethod
-    def from_pretrained(cls, adapter_id, revision=None):
-        d = hub._get_cached_revision_directory(adapter_id, revision)
-        filename = os.path.join(d, "adapter_config.json")
-        return cls.from_file(filename)
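With this in-house LoraConfig shim deleted (and its import dropped from the flash_llama model above), the diff does not show what reads adapter configs instead. The same information can be pulled straight from an adapter repository's adapter_config.json, for example via huggingface_hub, or through peft.LoraConfig.from_pretrained. A rough, assumption-labelled sketch, not the code that replaces this file:

import json

from huggingface_hub import hf_hub_download

# Hypothetical adapter repo id, for illustration only.
adapter_id = "some-org/some-lora-adapter"

config_path = hf_hub_download(adapter_id, "adapter_config.json")
with open(config_path) as f:
    adapter_config = json.load(f)

print(
    adapter_config["r"],
    adapter_config["lora_alpha"],
    adapter_config.get("target_modules"),
)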