chore: formatting

This commit is contained in:
OlivierDehaene 2023-12-11 14:49:52 +01:00
parent 3a521c92b3
commit 72ee382ded
36 changed files with 715 additions and 450 deletions

View File

@ -25,6 +25,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
class ResponseComparator(JSONSnapshotExtension): class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2 rtol = 0.2
def serialize( def serialize(
self, self,
data, data,
@ -69,7 +70,9 @@ class ResponseComparator(JSONSnapshotExtension):
prefill_token.id == other.id prefill_token.id == other.id
and prefill_token.text == other.text and prefill_token.text == other.text
and ( and (
math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol) math.isclose(
prefill_token.logprob, other.logprob, rel_tol=self.rtol
)
if prefill_token.logprob is not None if prefill_token.logprob is not None
else prefill_token.logprob == other.logprob else prefill_token.logprob == other.logprob
) )
@ -153,6 +156,7 @@ class GenerousResponseComparator(ResponseComparator):
# Needed for GPTQ with exllama which has serious numerical fluctuations. # Needed for GPTQ with exllama which has serious numerical fluctuations.
rtol = 0.75 rtol = 0.75
class LauncherHandle: class LauncherHandle:
def __init__(self, port: int): def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}") self.client = AsyncClient(f"http://localhost:{port}")
@ -198,6 +202,7 @@ class ProcessLauncherHandle(LauncherHandle):
def response_snapshot(snapshot): def response_snapshot(snapshot):
return snapshot.use_extension(ResponseComparator) return snapshot.use_extension(ResponseComparator)
@pytest.fixture @pytest.fixture
def generous_response_snapshot(snapshot): def generous_response_snapshot(snapshot):
return snapshot.use_extension(GenerousResponseComparator) return snapshot.use_extension(GenerousResponseComparator)
@ -219,7 +224,7 @@ def launcher(event_loop):
quantize: Optional[str] = None, quantize: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_flash_attention: bool = True, use_flash_attention: bool = True,
dtype: Optional[str] = None dtype: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000) master_port = random.randint(10_000, 20_000)
@ -282,7 +287,7 @@ def launcher(event_loop):
quantize: Optional[str] = None, quantize: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_flash_attention: bool = True, use_flash_attention: bool = True,
dtype: Optional[str] = None dtype: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
@ -335,7 +340,7 @@ def launcher(event_loop):
], ],
volumes=volumes, volumes=volumes,
ports={"80/tcp": port}, ports={"80/tcp": port},
shm_size="1G" shm_size="1G",
) )
yield ContainerLauncherHandle(client, container.name, port) yield ContainerLauncherHandle(client, container.name, port)

View File

@ -50,10 +50,16 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot): async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4) responses = await generate_load(
flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4 assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}" assert all(
assert responses[0].generated_text == '\nDeep learning is a subset of machine learning' [r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert (
responses[0].generated_text == "\nDeep learning is a subset of machine learning"
)
assert responses == response_snapshot assert responses == response_snapshot

View File

@ -56,7 +56,9 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
) )
assert len(responses) == 4 assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}" assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ": Let n = 10 - 1" assert responses[0].generated_text == ": Let n = 10 - 1"
assert responses == response_snapshot assert responses == response_snapshot

View File

@ -3,7 +3,9 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def idefics_handle(launcher): def idefics_handle(launcher):
with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle: with launcher(
"HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16"
) as handle:
yield handle yield handle

View File

@ -133,8 +133,20 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
) )
assert all([generation.generated_text is None for generation in generations]) assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids]) assert all(
assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts]) [
token_id.item() == 10264
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "Test"
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0 assert generations[0].request_id == 0

View File

@ -129,8 +129,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
) )
assert all([generation.generated_text is None for generation in generations]) assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids]) assert all(
assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts]) [
token_id.item() == 13
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == "."
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0 assert generations[0].request_id == 0

View File

@ -151,8 +151,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
) )
assert all([generation.generated_text is None for generation in generations]) assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids]) assert all(
assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts]) [
token_id.item() == 259
for generation in generations
for token_id in generation.tokens.token_ids
]
)
assert all(
[
token_text == " "
for generation in generations
for token_text in generation.tokens.texts
]
)
assert generations[0].request_id == 0 assert generations[0].request_id == 0

View File

@ -77,12 +77,24 @@ def serve(
# Downgrade enum into str for easier management later on # Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value dtype = None if dtype is None else dtype.value
if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}: if dtype is not None and quantize not in {
None,
"bitsandbytes",
"bitsandbytes-nf4",
"bitsandbytes-fp4",
}:
raise RuntimeError( raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
) )
server.serve( server.serve(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
) )
@ -140,12 +152,17 @@ def download_weights(
try: try:
import json import json
medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
medusa_head = hf_hub_download(
model_id, revision=revision, filename="medusa_lm_head.pt"
)
if auto_convert: if auto_convert:
medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors") medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
if not medusa_sf.exists(): if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], []) utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json") medusa_config = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(medusa_config, "r") as f: with open(medusa_config, "r") as f:
config = json.load(f) config = json.load(f)
@ -153,10 +170,17 @@ def download_weights(
revision = "main" revision = "main"
try: try:
utils.weight_files(model_id, revision, extension) utils.weight_files(model_id, revision, extension)
logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.") logger.info(
f"Files for parent {model_id} are already present on the host. "
"Skipping download."
)
return return
# Local files not found # Local files not found
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError): except (
utils.LocalEntryNotFoundError,
FileNotFoundError,
utils.EntryNotFoundError,
):
pass pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass

View File

@ -88,7 +88,6 @@ if MIXTRAL:
__all__.append(FlashMixtral) __all__.append(FlashMixtral)
def get_model( def get_model(
model_id: str, model_id: str,
revision: Optional[str], revision: Optional[str],
@ -157,7 +156,9 @@ def get_model(
speculate_medusa = config_dict["medusa_num_heads"] speculate_medusa = config_dict["medusa_num_heads"]
if speculate is not None: if speculate is not None:
if speculate > speculate_medusa: if speculate > speculate_medusa:
raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match") raise RuntimeError(
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
)
else: else:
set_speculate(speculate) set_speculate(speculate)
else: else:
@ -249,7 +250,7 @@ def get_model(
quantize=quantize, quantize=quantize,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa use_medusa=use_medusa,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@ -313,7 +314,9 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks") raise NotImplementedError(
"Mixtral models requires flash attention v2, stk and megablocks"
)
if model_type == "opt": if model_type == "opt":
return OPTSharded( return OPTSharded(
@ -354,7 +357,7 @@ def get_model(
raise ValueError("awq quantization is not supported for AutoModel") raise ValueError("awq quantization is not supported for AutoModel")
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"): elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError("4bit quantization is not supported for AutoModel") raise ValueError("4bit quantization is not supported for AutoModel")
elif (quantize == "eetq"): elif quantize == "eetq":
raise ValueError("Eetq quantization is not supported for AutoModel") raise ValueError("Eetq quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM( return CausalLM(

View File

@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer", filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
prefix="transformer",
) )
if config.quantize == "gptq": if config.quantize == "gptq":
weights._set_gptq_params(model_id) weights._set_gptq_params(model_id)

View File

@ -510,7 +510,11 @@ class CausalLM(Model):
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes": if (
torch.cuda.is_available()
and torch.cuda.device_count() == 1
and quantize != "bitsandbytes"
):
model = model.cuda() model = model.cuda()
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
@ -676,7 +680,10 @@ class CausalLM(Model):
skip_special_tokens=False, skip_special_tokens=False,
) )
prefill_tokens = Tokens( prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[] prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
) )
else: else:
prefill_tokens = None prefill_tokens = None

View File

@ -34,9 +34,10 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, TensorParallelHead,
get_linear, get_linear,
FastRMSNorm FastRMSNorm,
) )
class LlamaConfig(PretrainedConfig): class LlamaConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
@ -288,7 +289,9 @@ class FlashLlamaLayer(nn.Module):
) )
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps) self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load( self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm", prefix=f"{prefix}.post_attention_layernorm",
weights=weights, weights=weights,

View File

@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA from text_generation_server.utils.flash_attn import (
attention,
HAS_FLASH_ATTN_V2_ROCM,
HAS_FLASH_ATTN_V2_CUDA,
)
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -35,7 +39,7 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, TensorParallelHead,
get_linear, get_linear,
FastRMSNorm FastRMSNorm,
) )
@ -96,6 +100,7 @@ class MistralConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)

View File

@ -29,7 +29,10 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA from text_generation_server.utils.flash_attn import (
HAS_FLASH_ATTN_V2_ROCM,
HAS_FLASH_ATTN_V2_CUDA,
)
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear, FastLinear,
FastRMSNorm, FastRMSNorm,
@ -173,9 +176,11 @@ def _load_experts(config, prefix, mat, weights):
start = rank * block_size start = rank * block_size
stop = (rank + 1) * block_size stop = (rank + 1) * block_size
tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size), tensor = torch.empty(
(config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype, dtype=weights.dtype,
device=weights.device) device=weights.device,
)
for i in range(config.num_local_experts): for i in range(config.num_local_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
@ -184,7 +189,9 @@ def _load_experts(config, prefix, mat, weights):
expert_slice = slice_[:, start:stop].t().contiguous() expert_slice = slice_[:, start:stop].t().contiguous()
else: else:
expert_slice = slice_[start:stop] expert_slice = slice_[start:stop]
tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device) tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor return tensor
@ -210,7 +217,7 @@ class MixtralAttention(torch.nn.Module):
device=weights.device, device=weights.device,
) )
self.softmax_scale = self.head_size ** -0.5 self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0: if self.num_heads % weights.process_group.size() != 0:
raise ValueError( raise ValueError(
@ -399,8 +406,9 @@ class BlockSparseMoE(nn.Module):
# Indices for the sparse matrix. The indices for # Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending # the intermediate matrix are dynamic depending
# on the mapping of tokens to experts. # on the mapping of tokens to experts.
column_indices = ops.topology(padded_bins, self.blocking, block_rows, column_indices = ops.topology(
blocks_per_row) padded_bins, self.blocking, block_rows, blocks_per_row
)
# For now, use meta init to save the device memory. # For now, use meta init to save the device memory.
data = torch.empty( data = torch.empty(
@ -444,8 +452,7 @@ class BlockSparseMoE(nn.Module):
# position of each bin. # position of each bin.
# List of size num_experts # List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert, padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...] # padded_tokens_per_expert => [128, O, 128, ...]
# Cumulative selected experts per token # Cumulative selected experts per token
@ -484,8 +491,7 @@ class BlockSparseMoE(nn.Module):
# Permute tokens and pad to prepare expert computation # Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim) # (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
self.top_k)
# Create the sparse matrix topology # Create the sparse matrix topology
with torch.no_grad(): with torch.no_grad():
@ -496,8 +502,8 @@ class BlockSparseMoE(nn.Module):
# (top_k * sequence_length + padding, ffn_dim * n_experts) # (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix( x = stk.Matrix(
topo.size(), topo.size(),
self.act(stk.ops.sdd(x, self.w1, topo).data) * self.act(stk.ops.sdd(x, self.w1, topo).data)
stk.ops.sdd(x, self.w3, topo).data, * stk.ops.sdd(x, self.w3, topo).data,
topo.row_indices, topo.row_indices,
topo.column_indices, topo.column_indices,
topo.offsets, topo.offsets,
@ -537,7 +543,9 @@ class MixtralLayer(nn.Module):
self.self_attn = MixtralAttention( self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights) self.block_sparse_moe = BlockSparseMoE(
f"{prefix}.block_sparse_moe", config, weights
)
self.input_layernorm = FastRMSNorm.load( self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps

View File

@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
image = image_url_or_urls image = image_url_or_urls
if image.startswith("http://") or image.startswith("https://"): if image.startswith("http://") or image.startswith("https://"):
response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)) response = requests.get(
image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
)
response.raise_for_status() response.raise_for_status()
content = response.content content = response.content
elif image.startswith("data:"): elif image.startswith("data:"):

View File

@ -62,6 +62,7 @@ if IS_CUDA_SYSTEM:
elif IS_ROCM_SYSTEM: elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops from vllm import layernorm_ops
@dataclass @dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast): class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
image_hidden_states: Optional[torch.FloatTensor] = None image_hidden_states: Optional[torch.FloatTensor] = None
@ -431,7 +432,9 @@ class IdeficsRMSNorm(nn.Module):
return out return out
else: else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
# this was adapted from LlamaMLP # this was adapted from LlamaMLP
@ -613,7 +616,12 @@ class IdeficsAttention(nn.Module):
query_shape = query_states.shape query_shape = query_states.shape
key_shape = key_states.shape key_shape = key_states.shape
self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin) self.rotary_emb(
query_states.view(-1, *query_shape[2:]),
key_states.reshape(-1, *key_shape[2:]),
cos,
sin,
)
query_states = query_states.view(query_shape) query_states = query_states.view(query_shape)
key_states = key_states.view(key_shape) key_states = key_states.view(key_shape)

View File

@ -112,6 +112,7 @@ def is_url(string):
result = urlparse(string) result = urlparse(string)
return all([result.scheme, result.netloc]) return all([result.scheme, result.netloc])
def is_image(string): def is_image(string):
"""Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
invalidated the url""" invalidated the url"""
@ -344,7 +345,6 @@ class IdeficsProcessor(ProcessorMixin):
image_objects = self.image_processor(image_objects, transform=transform) image_objects = self.image_processor(image_objects, transform=transform)
text_encoding = self.tokenizer( text_encoding = self.tokenizer(
text=full_text, text=full_text,
add_special_tokens=False, add_special_tokens=False,

View File

@ -165,8 +165,6 @@ class FlashCausalLMBatch(Batch):
input_length = len(tokenized_input) input_length = len(tokenized_input)
input_lengths.append(input_length) input_lengths.append(input_length)
prefix_offsets.append(input_length - 5) prefix_offsets.append(input_length - 5)
read_offsets.append(input_length) read_offsets.append(input_length)
@ -229,7 +227,9 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length += total_tokens cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks) max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens + speculative_length) max_length = max(
max_length, input_length + max_new_tokens + speculative_length
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device next_token_chooser_parameters, dtype, device
@ -424,7 +424,9 @@ class FlashCausalLMBatch(Batch):
slots = self.slots[slot_filtering_indices] slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices] top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None speculative_ids = (
self.speculative_ids[indices] if self.speculative_ids is not None else None
)
start_slots = torch.tensor(start_slots, dtype=torch.int64) start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -480,7 +482,9 @@ class FlashCausalLMBatch(Batch):
total_batch_size += len(b) total_batch_size += len(b)
total_slots += len(b.slots) total_slots += len(b.slots)
blocks += b.blocks blocks += b.blocks
speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
max_blocks = max(max_blocks, b.max_blocks) max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen) max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max( max_length = max(
@ -586,7 +590,11 @@ class FlashCausalLMBatch(Batch):
device=batches[0].next_token_chooser.device, device=batches[0].next_token_chooser.device,
) )
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None speculative_ids = (
torch.cat([b.speculative_ids for b in batches], dim=0)
if batches[0].speculative_ids is not None
else None
)
# Needed to avoid dropping blocks when the batches will go out of scope # Needed to avoid dropping blocks when the batches will go out of scope
for b in batches: for b in batches:
@ -622,7 +630,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor, top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
speculative_ids=speculative_ids speculative_ids=speculative_ids,
) )
def __del__(self): def __del__(self):
@ -727,43 +735,54 @@ class FlashCausalLM(Model):
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward # Model Forward
if batch.speculative_ids is not None: if batch.speculative_ids is not None:
input_ids=batch.input_ids input_ids = batch.input_ids
position_ids=batch.position_ids position_ids = batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache kv_cache = get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots=batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s=batch.max_seqlen max_s = batch.max_seqlen
lm_head_indices=batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1 new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1) new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32) arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous() block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length max_s = max_s + speculative_length
input_ids = new_input_ids input_ids = new_input_ids
position_ids = new_position_ids position_ids = new_position_ids
else: else:
input_ids=batch.input_ids input_ids = batch.input_ids
position_ids=batch.position_ids position_ids = batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache kv_cache = get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots=batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s=batch.max_seqlen max_s = batch.max_seqlen
lm_head_indices=batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
return self.model.forward( return self.model.forward(
input_ids=input_ids, input_ids=input_ids,
@ -808,20 +827,31 @@ class FlashCausalLM(Model):
else: else:
speculative_logits = None speculative_logits = None
if prefill: if prefill:
next_token_logits = ( next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out out[batch.prefill_next_token_indices] if prefill_logprobs else out
) )
if speculative_logits is not None: if speculative_logits is not None:
speculative_logits = ( speculative_logits = (
speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits speculative_logits[batch.prefill_next_token_indices]
if prefill_logprobs
else speculative_logits
) )
else: else:
next_token_logits = out next_token_logits = out
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( (
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits next_input_ids,
next_token_logprobs,
logprobs,
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen],
next_token_logits,
get_speculate(),
batch.speculative_ids,
speculative_logits,
) )
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
@ -851,11 +881,7 @@ class FlashCausalLM(Model):
stopped = True stopped = True
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
batch.input_lengths,
batch.all_input_ids,
accepted_ids
)
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
# one, we need to first do a GPU <-> CPU sync # one, we need to first do a GPU <-> CPU sync
@ -863,11 +889,7 @@ class FlashCausalLM(Model):
# For each member of the batch # For each member of the batch
index = 0 index = 0
for i, ( for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
input_length,
all_input_ids,
n_accepted_ids
) in enumerate(iterator):
# Indexing metadata # Indexing metadata
start_index = cumulative_length start_index = cumulative_length
end_index = cumulative_length + input_length end_index = cumulative_length + input_length
@ -901,7 +923,6 @@ class FlashCausalLM(Model):
cumulative_length += input_length cumulative_length += input_length
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids batch.position_ids = next_position_ids + accepted_ids
@ -983,8 +1004,10 @@ class FlashCausalLM(Model):
current_stopped = False current_stopped = False
stopped = stopped and current_stopped stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index: index+n_accepted_ids - left] _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left] _next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
index += n_accepted_ids index += n_accepted_ids
# Shard generations # Shard generations
@ -1027,7 +1050,10 @@ class FlashCausalLM(Model):
) )
prefill_tokens = Tokens( prefill_tokens = Tokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = [] prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
) )
else: else:
prefill_tokens = None prefill_tokens = None

View File

@ -71,12 +71,19 @@ class FlashLlama(FlashCausalLM):
from text_generation_server.utils.medusa import MedusaModel from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import json import json
medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
with open(medusa_config, "r") as f: with open(medusa_config, "r") as f:
config = json.load(f) config = json.load(f)
medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt") medusa_head = hf_hub_download(
medusa_sf = medusa_head[:-len(".pt")] + ".safetensors" use_medusa, revision=revision, filename="medusa_lm_head.pt"
weights = Weights([medusa_sf], device, dtype, process_group=self.process_group) )
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head) model.lm_head = MedusaModel(config, weights, lm_head)

View File

@ -104,7 +104,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
# request id -> idx in list mapping # request id -> idx in list mapping
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate:] tokenized_input = tokenized_input[-r.truncate :]
input_length = len(tokenized_input) input_length = len(tokenized_input)
input_lengths.append(input_length) input_lengths.append(input_length)
@ -184,7 +184,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
cumulative_max_length += total_tokens cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length) max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks) max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens + speculative_length) max_length = max(
max_length, input_length + max_new_tokens + speculative_length
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device next_token_chooser_parameters, dtype, device
@ -273,7 +275,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,
speculative_ids=None speculative_ids=None,
) )
@ -345,43 +347,54 @@ class BaseFlashMistral(FlashCausalLM):
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward # Model Forward
if batch.speculative_ids is not None: if batch.speculative_ids is not None:
input_ids=batch.input_ids input_ids = batch.input_ids
position_ids=batch.position_ids position_ids = batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache kv_cache = get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots=batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s=batch.max_seqlen max_s = batch.max_seqlen
lm_head_indices=batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1 new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1) new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32) arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1) new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous() block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length max_s = max_s + speculative_length
input_ids = new_input_ids input_ids = new_input_ids
position_ids = new_position_ids position_ids = new_position_ids
else: else:
input_ids=batch.input_ids input_ids = batch.input_ids
position_ids=batch.position_ids position_ids = batch.position_ids
cu_seqlen_prefill=batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache=get_cache_manager().kv_cache kv_cache = get_cache_manager().kv_cache
block_tables=batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots=batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths=batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s=batch.max_seqlen max_s = batch.max_seqlen
lm_head_indices=batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
logits = self.model.forward( logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -415,5 +428,5 @@ class FlashMistral(BaseFlashMistral):
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code trust_remote_code=trust_remote_code,
) )

View File

@ -3,7 +3,10 @@ import torch
from typing import Optional from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
MixtralConfig,
FlashMixtralForCausalLM,
)
class FlashMixtral(BaseFlashMistral): class FlashMixtral(BaseFlashMistral):
@ -22,5 +25,5 @@ class FlashMixtral(BaseFlashMistral):
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code trust_remote_code=trust_remote_code,
) )

View File

@ -792,7 +792,10 @@ class IdeficsCausalLM(Model):
skip_special_tokens=False, skip_special_tokens=False,
) )
prefill_tokens = Tokens( prefill_tokens = Tokens(
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[] prefill_token_ids,
prefill_logprobs,
prefill_texts,
is_special=[],
) )
else: else:
prefill_tokens = None prefill_tokens = None

View File

@ -56,7 +56,7 @@ class Model(ABC):
dtype=str(self.dtype), dtype=str(self.dtype),
device_type=self.device.type, device_type=self.device.type,
window_size=self.sliding_window, window_size=self.sliding_window,
speculate=self.speculate speculate=self.speculate,
) )
@property @property

View File

@ -736,7 +736,7 @@ class Seq2SeqLM(Model):
[self.tokenizer.bos_token_id], [self.tokenizer.bos_token_id],
[float("nan")], [float("nan")],
[self.tokenizer.bos_token], [self.tokenizer.bos_token],
[False] [False],
) )
else: else:
prefill_tokens = None prefill_tokens = None

View File

@ -66,7 +66,10 @@ class Tokens:
def to_pb(self) -> generate_pb2.Tokens: def to_pb(self) -> generate_pb2.Tokens:
return generate_pb2.Tokens( return generate_pb2.Tokens(
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
) )
def __len__(self): def __len__(self):

View File

@ -159,7 +159,13 @@ def serve(
try: try:
model = get_model( model = get_model(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
) )
except Exception: except Exception:
logger.exception("Error when initializing model") logger.exception("Error when initializing model")
@ -207,5 +213,7 @@ def serve(
await server.stop(0) await server.stop(0)
asyncio.run( asyncio.run(
serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code) serve_inner(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
)
) )

View File

@ -51,7 +51,9 @@ except ImportError as e:
) from e ) from e
elif IS_ROCM_SYSTEM: elif IS_ROCM_SYSTEM:
for idx in range(torch.cuda.device_count()): for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx): if "MI210" not in torch.cuda.get_device_name(
idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError( raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
) )
@ -91,7 +93,9 @@ def attention(
) )
elif HAS_FLASH_ATTN_V2_ROCM: elif HAS_FLASH_ATTN_V2_ROCM:
if window_size_left != -1: if window_size_left != -1:
raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).") raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# RoCm flash API does not take the window_size_left and window_size_right arguments. # RoCm flash API does not take the window_size_left and window_size_right arguments.
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(

View File

@ -11,20 +11,22 @@ logger = getLogger(__name__)
try: try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError: except ImportError:
logger.error('exllamav2_kernels not installed.') logger.error("exllamav2_kernels not installed.")
raise raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta") none_tensor = torch.empty((1, 1), device="meta")
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4""" """Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,) output_shape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device) output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
gemm_half_q_half(x, q_handle, output, force_cuda) gemm_half_q_half(x, q_handle, output, force_cuda)
return output.view(output_shape) return output.view(output_shape)
def ext_make_q_matrix(w: dict, temp_dq, key: str = None): def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
""" """
Create Q matrix Create Q matrix
@ -35,7 +37,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale_max"] /= 256 w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short() w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short() w["q_invperm"] = w["q_invperm"].short()
return make_q_matrix(w["q_weight"], return make_q_matrix(
w["q_weight"],
w["q_perm"], w["q_perm"],
w["q_invperm"], w["q_invperm"],
w["q_scale"], w["q_scale"],
@ -44,7 +47,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
temp_dq) temp_dq,
)
# GPTQ # GPTQ
elif "qweight" in w: elif "qweight" in w:
if w["scales"].dtype == torch.float: if w["scales"].dtype == torch.float:
@ -52,10 +56,15 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
# GPTQ with g_idx (act_order) # GPTQ with g_idx (act_order)
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device) w["q_perm"] = torch.empty(
(w["qweight"].shape[0] * 8,),
dtype=torch.short,
device=w["qweight"].device,
)
w["q_invperm"] = torch.empty_like(w["q_perm"]) w["q_invperm"] = torch.empty_like(w["q_perm"])
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(w["qweight"], return make_q_matrix(
w["qweight"],
w["q_perm"], w["q_perm"],
w["q_invperm"], w["q_invperm"],
none_tensor, none_tensor,
@ -64,10 +73,12 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["qzeros"], w["qzeros"],
w["scales"], w["scales"],
w["g_idx"].cpu(), w["g_idx"].cpu(),
temp_dq) temp_dq,
)
# GPTQ without g_idx # GPTQ without g_idx
else: else:
return make_q_matrix(w["qweight"], return make_q_matrix(
w["qweight"],
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
@ -76,7 +87,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["qzeros"], w["qzeros"],
w["scales"], w["scales"],
none_tensor, none_tensor,
temp_dq) temp_dq,
)
DEVICE = None DEVICE = None
FIXED_BYTES = 0 FIXED_BYTES = 0
@ -106,14 +119,15 @@ class QuantLinear(nn.Module):
super().__init__() super().__init__()
if bits != 4: if bits != 4:
raise ValueError( raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.") f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
)
self.q_handle = None self.q_handle = None
self.q_tensors = None self.q_tensors = None
self.bits = bits self.bits = bits
self.maxq = 2 ** self.bits - 1 self.maxq = 2**self.bits - 1
self.infeatures = qweight.shape[0] // self.bits * 32 self.infeatures = qweight.shape[0] // self.bits * 32
self.outfeatures = qweight.shape[1] self.outfeatures = qweight.shape[1]
self.padding = - self.outfeatures % 32 self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding self.outfeatures = self.outfeatures + self.padding
self.device = qweight.device self.device = qweight.device
@ -128,9 +142,12 @@ class QuantLinear(nn.Module):
outfeatures = self.outfeatures outfeatures = self.outfeatures
assert qweight.shape == (infeatures // 32 * self.bits, outfeatures) assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
assert infeatures % self.group_size == 0 assert infeatures % self.group_size == 0
assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits) assert qzeros.shape == (
infeatures // self.group_size,
outfeatures // 32 * self.bits,
)
assert scales.shape == (infeatures // self.group_size, outfeatures) assert scales.shape == (infeatures // self.group_size, outfeatures)
assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}" assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
global FIXED_BYTES, LAYERS global FIXED_BYTES, LAYERS
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
@ -140,17 +157,15 @@ class QuantLinear(nn.Module):
assert self.qweight.device.type == "cuda" assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None assert self.qweight.device.index is not None
self.q_tensors = { self.q_tensors = {
"qweight":self.qweight, "qweight": self.qweight,
"qzeros":self.qzeros, "qzeros": self.qzeros,
"scales":self.scales, "scales": self.scales,
"g_idx":self.g_idx "g_idx": self.g_idx,
} }
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
self.q_handle = ext_make_q_matrix( self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
self.q_tensors, temp_dq
)
def forward(self, x, force_cuda = False): def forward(self, x, force_cuda=False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
if self.bias is not None: if self.bias is not None:
@ -179,11 +194,14 @@ class ExLlamaV2DeviceTensors:
self.scratch_bytes = scratch_bytes self.scratch_bytes = scratch_bytes
def prepare(self): def prepare(self):
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device) self.scratch = torch.empty(
(self.scratch_bytes // 2,), dtype=torch.half, device=self.device
)
def get_scratch_slice(self, size_bytes): def get_scratch_slice(self, size_bytes):
if self.scratch is None: self.prepare() if self.scratch is None:
self.prepare()
size_bytes = ((size_bytes + 127) // 128) * 128 size_bytes = ((size_bytes + 127) // 128) * 128
size_half = size_bytes // 2 size_half = size_bytes // 2

View File

@ -35,7 +35,9 @@ HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding") logger.warning(
"Disabling exllama v2 and using v1 instead because there are issues when sharding"
)
V2 = False V2 = False
if os.getenv("DISABLE_EXLLAMA") == "True": if os.getenv("DISABLE_EXLLAMA") == "True":
@ -43,14 +45,16 @@ if os.getenv("DISABLE_EXLLAMA") == "True":
elif CAN_EXLLAMA: elif CAN_EXLLAMA:
try: try:
if V2: if V2:
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear, from text_generation_server.utils.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear,
create_exllama_buffers, create_exllama_buffers,
set_device, set_device,
) )
HAS_EXLLAMA = "2" HAS_EXLLAMA = "2"
else: else:
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear, from text_generation_server.utils.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear,
create_exllama_buffers, create_exllama_buffers,
set_device, set_device,
) )
@ -325,7 +329,9 @@ def get_linear(weight, bias, quantize):
) )
if use_exllama: if use_exllama:
linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) linear = ExllamaQuantLinear(
qweight, qzeros, scales, g_idx, bias, bits, groupsize
)
else: else:
linear = QuantLinear( linear = QuantLinear(
qweight, qweight,
@ -533,7 +539,6 @@ try:
else: else:
dropout_layer_norm = None dropout_layer_norm = None
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
@ -569,7 +574,6 @@ try:
return normed_hidden_states, residual return normed_hidden_states, residual
class FastRMSNorm(nn.Module): class FastRMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float): def __init__(self, weight: torch.Tensor, eps: float):
super().__init__() super().__init__()
@ -601,7 +605,11 @@ try:
return self.weight * hidden_states, residual return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM: elif IS_CUDA_SYSTEM:
# faster post attention rms norm # faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( (
normed_hidden_states,
res,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states, hidden_states,
residual, residual,
self.weight, self.weight,
@ -638,7 +646,8 @@ try:
return out, residual return out, residual
else: else:
raise ValueError( raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
except ImportError: except ImportError:
pass pass
@ -650,14 +659,12 @@ try:
elif IS_ROCM_SYSTEM: elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops from vllm import pos_encoding_ops
def _create_inv_freq(dim, base, device): def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / ( inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
) )
return inv_freq return inv_freq
def _get_rope_config(config): def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None: if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = { rope_scaling = {
@ -667,7 +674,6 @@ try:
return rope_scaling return rope_scaling
return getattr(config, "rope_scaling", None) return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module): class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor): def __init__(self, inv_freq, scaling_factor):
super().__init__() super().__init__()
@ -680,17 +686,23 @@ try:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
self.dynamic_args = None self.dynamic_args = None
def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
# Such controlflows may add some overhead. # Such controlflows may add some overhead.
if IS_CUDA_SYSTEM: if IS_CUDA_SYSTEM:
rotary_dim = cos.shape[-1] rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim] q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim: 2 * rotary_dim] q2 = query[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim] k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim: 2 * rotary_dim] k2 = key[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM: elif IS_ROCM_SYSTEM:
@ -700,17 +712,11 @@ try:
head_size = query.shape[-1] head_size = query.shape[-1]
# Inplace operation, updating query and key. # Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding( pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
query,
key,
head_size,
cos,
sin,
True
)
else: else:
raise ValueError( raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
@classmethod @classmethod
def static(cls, config, dim, base, device): def static(cls, config, dim, base, device):
@ -732,15 +738,16 @@ try:
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"], max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0, base=10000.0,
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
extrapolation_factor=1, extrapolation_factor=1,
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1 beta_slow=1,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -773,15 +780,16 @@ try:
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"], max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0, base=10000.0,
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
extrapolation_factor=1, extrapolation_factor=1,
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1 beta_slow=1,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -827,7 +835,6 @@ try:
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1) return cos.unsqueeze(1), sin.unsqueeze(1)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
@ -861,24 +868,28 @@ try:
self._cos_cached = torch.cos(freqs).to(dtype) self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype)
# Inverse dim formula to find dim based on number of rotations # Inverse dim formula to find dim based on number of rotations
import math import math
def find_correction_dim(
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): num_rotations, dim, base=10000, max_position_embeddings=2048
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) ):
return (
dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
) / (2 * math.log(base))
# Find dim range bounds based on rotations # Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): def find_correction_range(
low = math.floor(find_correction_dim( low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
low_rot, dim, base, max_position_embeddings)) ):
high = math.ceil(find_correction_dim( low = math.floor(
high_rot, dim, base, max_position_embeddings)) find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim): def linear_ramp_mask(min, max, dim):
if min == max: if min == max:
max += 0.001 # Prevent singularity max += 0.001 # Prevent singularity
@ -887,16 +898,25 @@ try:
ramp_func = torch.clamp(linear_func, 0, 1) ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func return ramp_func
def get_mscale(scale=1): def get_mscale(scale=1):
if scale <= 1: if scale <= 1:
return 1.0 return 1.0
return 0.1 * math.log(scale) + 1.0 return 0.1 * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor, def __init__(
attn_factor, beta_fast, beta_slow): self,
dim,
max_position_embeddings,
base,
device,
scaling_factor,
*,
extrapolation_factor,
attn_factor,
beta_fast,
beta_slow,
):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor) super().__init__(inv_freq, scaling_factor)
self.dim = dim self.dim = dim
@ -906,8 +926,9 @@ try:
self.attn_factor = attn_factor self.attn_factor = attn_factor
self.beta_fast = beta_fast self.beta_fast = beta_fast
self.beta_slow = beta_slow self.beta_slow = beta_slow
self.mscale = float(get_mscale( self.mscale = float(
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
@ -923,15 +944,26 @@ try:
) )
freqs = 1.0 / inv_freq_extrapolation freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, low, high = find_correction_range(
self.max_position_embeddings) self.beta_fast,
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to( self.beta_slow,
device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation self.dim,
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask self.base,
self.max_position_embeddings,
)
inv_freq_mask = (
1
- linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
self.inv_freq = inv_freq self.inv_freq = inv_freq
self.mscale = float(get_mscale( self.mscale = float(
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)

View File

@ -2,6 +2,7 @@ import torch
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass @dataclass
class Output: class Output:
logits: torch.FloatTensor = None logits: torch.FloatTensor = None
@ -11,7 +12,9 @@ class Output:
class ResBlock(torch.nn.Module): class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True) self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU() self.act = torch.nn.SiLU()
def forward(self, x): def forward(self, x):
@ -19,15 +22,13 @@ class ResBlock(torch.nn.Module):
class MedusaModel(torch.nn.Module): class MedusaModel(torch.nn.Module):
def __init__( def __init__(self, config, weights, lm_head):
self,
config,
weights,
lm_head
):
super().__init__() super().__init__()
self.heads = torch.nn.ModuleList( self.heads = torch.nn.ModuleList(
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])] [
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
) )
self.lm_head = lm_head self.lm_head = lm_head
@ -40,9 +41,16 @@ class MedusaModel(torch.nn.Module):
class MedusaHead(torch.nn.Module): class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])]) self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks) n = len(self.blocks)
self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False) self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x): def forward(self, x):
for block in self.blocks: for block in self.blocks:

View File

@ -7,11 +7,14 @@ from vllm import attention_ops
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, def reshape_and_cache(
slots: torch.Tensor): key: torch.Tensor,
cache_ops.reshape_and_cache( value: torch.Tensor,
key, value, key_cache, value_cache, slots key_cache: torch.Tensor,
) value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
def attention( def attention(
@ -45,9 +48,7 @@ def attention(
# value_cache => [num_blocks, num_heads, head_size, block_size] # value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
max_num_partitions = ( max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
(max_s + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of

View File

@ -38,7 +38,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
os.makedirs(model_id, exist_ok=True) os.makedirs(model_id, exist_ok=True)
cache_dir = model_id cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}") logger.info(f"Saving the newly created merged model to {cache_dir}")
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code) tokenizer = AutoTokenizer.from_pretrained(
base_model_id, trust_remote_code=trust_remote_code
)
model.save_pretrained(cache_dir, safe_serialization=True) model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir) model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir) tokenizer.save_pretrained(cache_dir)

View File

@ -1,12 +1,11 @@
SPECULATE = None SPECULATE = None
def get_speculate() -> int: def get_speculate() -> int:
global SPECULATE global SPECULATE
return SPECULATE return SPECULATE
def set_speculate(speculate: int): def set_speculate(speculate: int):
global SPECULATE global SPECULATE
SPECULATE = speculate SPECULATE = speculate

View File

@ -16,6 +16,7 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser: class NextTokenChooser:
def __init__( def __init__(
self, self,
@ -145,21 +146,31 @@ class StoppingCriteria:
pb.ignore_eos_token, pb.ignore_eos_token,
) )
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
def create_n_gram_speculation(
input_ids: torch.Tensor,
next_ids: torch.Tensor,
accepted_ids: torch.Tensor,
speculate: int,
verbose: bool,
):
# Very trivial approach, find first match in the string. # Very trivial approach, find first match in the string.
# This is much less refined than actual n-gram but seems to work # This is much less refined than actual n-gram but seems to work
# relatively OK in grounded mode and is by far much faster with # relatively OK in grounded mode and is by far much faster with
# much less worst case complexity as everything happens on device. # much less worst case complexity as everything happens on device.
B = accepted_ids.shape[0] B = accepted_ids.shape[0]
device = input_ids.device device = input_ids.device
seeds = next_ids[accepted_ids.cumsum(dim=-1) -1 ] seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1 indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device) all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
speculate, device=device
)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1) all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices) speculative_ids = input_ids.gather(dim=-1, index=all_indices)
return speculative_ids return speculative_ids
class HeterogeneousNextTokenChooser: class HeterogeneousNextTokenChooser:
def __init__( def __init__(
self, self,
@ -228,7 +239,15 @@ class HeterogeneousNextTokenChooser:
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False): def __call__(
self,
input_ids: torch.Tensor,
scores: torch.Tensor,
speculate: int,
speculated_ids: Optional[torch.Tensor] = None,
speculative_scores: Optional[torch.Tensor] = None,
verbose=False,
):
if speculated_ids is not None: if speculated_ids is not None:
B = scores.shape[0] // (speculated_ids.shape[1] + 1) B = scores.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1 S = speculated_ids.shape[1] + 1
@ -249,12 +268,11 @@ class HeterogeneousNextTokenChooser:
for warper in self.warpers: for warper in self.warpers:
_scores = warper(input_ids, _scores) _scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores) _next_ids = self.choice(_scores)
scores[:, j] = _scores scores[:, j] = _scores
next_ids[:, j] = _next_ids next_ids[:, j] = _next_ids
next_ids = next_ids.view(B*S) next_ids = next_ids.view(B * S)
scores = scores.view( B* S, -1) scores = scores.view(B * S, -1)
if speculated_ids is not None: if speculated_ids is not None:
accepted_ids = [] accepted_ids = []
@ -262,7 +280,7 @@ class HeterogeneousNextTokenChooser:
S = speculated_ids.shape[1] + 1 S = speculated_ids.shape[1] + 1
indices = [] indices = []
for i in range(B): for i in range(B):
_next_ids = next_ids[i*S: (i + 1)*S] _next_ids = next_ids[i * S : (i + 1) * S]
_speculated_ids = speculated_ids[i] _speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S index = i * S
@ -278,7 +296,9 @@ class HeterogeneousNextTokenChooser:
break break
accepted_ids.append(accepted) accepted_ids.append(accepted)
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype) accepted_ids = torch.tensor(
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
)
next_ids = next_ids[indices] next_ids = next_ids[indices]
scores = scores[indices] scores = scores[indices]
indices = torch.arange(B, device=input_ids.device) * S indices = torch.arange(B, device=input_ids.device) * S
@ -296,7 +316,9 @@ class HeterogeneousNextTokenChooser:
speculative_ids = Greedy()(speculative_scores) speculative_ids = Greedy()(speculative_scores)
else: else:
# n-gram # n-gram
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose) speculative_ids = create_n_gram_speculation(
input_ids, next_ids, accepted_ids, speculate, verbose
)
else: else:
speculative_ids = None speculative_ids = None

View File

@ -16,7 +16,7 @@ class Weights:
dtype, dtype,
process_group, process_group,
aliases: Optional[Dict[str, List[str]]] = None, aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None prefix: Optional[str] = None,
): ):
routing = {} routing = {}
for filename in filenames: for filename in filenames:
@ -213,7 +213,8 @@ class Weights:
bits, groupsize = self._get_gptq_params() bits, groupsize = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq"
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -283,7 +284,7 @@ class Weights:
if use_exllama: if use_exllama:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0) scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim= 0) g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
g_idx = g_idx - g_idx[0] g_idx = g_idx - g_idx[0]
else: else:
# The triton kernel reorders the scales/zero points instead of the weight/activation. # The triton kernel reorders the scales/zero points instead of the weight/activation.

View File

@ -21,14 +21,14 @@ def main():
block = [] block = []
for line in lines: for line in lines:
if line.startswith(" -") or line.startswith(" -"): if line.startswith(" -") or line.startswith(" -"):
rendered_block = '\n'.join(block) rendered_block = "\n".join(block)
if header: if header:
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
else: else:
final_doc += f"```shell\n{rendered_block}\n```\n" final_doc += f"```shell\n{rendered_block}\n```\n"
block = [] block = []
tokens = line.split("<") tokens = line.split("<")
if len(tokens)>1: if len(tokens) > 1:
header = tokens[-1][:-1] header = tokens[-1][:-1]
else: else:
header = line.split("--")[-1] header = line.split("--")[-1]
@ -36,7 +36,7 @@ def main():
block.append(line) block.append(line)
rendered_block = '\n'.join(block) rendered_block = "\n".join(block)
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
block = [] block = []