chore: formatting

OlivierDehaene 2023-12-11 14:49:52 +01:00
parent 3a521c92b3
commit 72ee382ded
36 changed files with 715 additions and 450 deletions

View File

@@ -25,6 +25,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
class ResponseComparator(JSONSnapshotExtension):
    rtol = 0.2

    def serialize(
        self,
        data,
@@ -69,7 +70,9 @@ class ResponseComparator(JSONSnapshotExtension):
                    prefill_token.id == other.id
                    and prefill_token.text == other.text
                    and (
                        math.isclose(
                            prefill_token.logprob, other.logprob, rel_tol=self.rtol
                        )
                        if prefill_token.logprob is not None
                        else prefill_token.logprob == other.logprob
                    )
@@ -153,6 +156,7 @@ class GenerousResponseComparator(ResponseComparator):
    # Needed for GPTQ with exllama which has serious numerical fluctuations.
    rtol = 0.75


class LauncherHandle:
    def __init__(self, port: int):
        self.client = AsyncClient(f"http://localhost:{port}")
@@ -198,6 +202,7 @@ class ProcessLauncherHandle(LauncherHandle):
def response_snapshot(snapshot):
    return snapshot.use_extension(ResponseComparator)


@pytest.fixture
def generous_response_snapshot(snapshot):
    return snapshot.use_extension(GenerousResponseComparator)
@@ -219,7 +224,7 @@ def launcher(event_loop):
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        dtype: Optional[str] = None,
    ):
        port = random.randint(8000, 10_000)
        master_port = random.randint(10_000, 20_000)
@@ -282,7 +287,7 @@ def launcher(event_loop):
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
        use_flash_attention: bool = True,
        dtype: Optional[str] = None,
    ):
        port = random.randint(8000, 10_000)
@@ -335,7 +340,7 @@ def launcher(event_loop):
            ],
            volumes=volumes,
            ports={"80/tcp": port},
            shm_size="1G",
        )

        yield ContainerLauncherHandle(client, container.name, port)
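
A note on the comparator being reformatted above: it treats two logprobs as equal when they fall within a relative tolerance (rtol = 0.2, or 0.75 for the GPTQ/exllama comparator). A minimal, self-contained sketch of that check, with illustrative values that are not part of the commit:

import math

# math.isclose(a, b, rel_tol=r) is True when |a - b| <= r * max(|a|, |b|).
a, b = -1.05, -0.95
print(math.isclose(a, b, rel_tol=0.2))   # True: the ~10% gap is within 20%
print(math.isclose(a, b, rel_tol=0.05))  # False: the gap exceeds 5% of max(|a|, |b|)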

View File

@@ -50,10 +50,16 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
    responses = await generate_load(
        flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"
    assert (
        responses[0].generated_text == "\nDeep learning is a subset of machine learning"
    )

    assert responses == response_snapshot

View File

@@ -56,7 +56,9 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
    )

    assert len(responses) == 4
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"{[r.generated_text for r in responses]}"
    assert responses[0].generated_text == ": Let n = 10 - 1"

    assert responses == response_snapshot

View File

@@ -3,7 +3,9 @@ import pytest
@pytest.fixture(scope="module")
def idefics_handle(launcher):
    with launcher(
        "HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16"
    ) as handle:
        yield handle

View File

@@ -133,8 +133,20 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
    )
    assert all([generation.generated_text is None for generation in generations])
    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
    assert all(
        [
            token_id.item() == 10264
            for generation in generations
            for token_id in generation.tokens.token_ids
        ]
    )
    assert all(
        [
            token_text == "Test"
            for generation in generations
            for token_text in generation.tokens.texts
        ]
    )
    assert generations[0].request_id == 0

View File

@@ -129,8 +129,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
    )
    assert all([generation.generated_text is None for generation in generations])
    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
    assert all(
        [
            token_id.item() == 13
            for generation in generations
            for token_id in generation.tokens.token_ids
        ]
    )
    assert all(
        [
            token_text == "."
            for generation in generations
            for token_text in generation.tokens.texts
        ]
    )
    assert generations[0].request_id == 0

View File

@@ -151,8 +151,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
    )
    assert all([generation.generated_text is None for generation in generations])
    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
    assert all(
        [
            token_id.item() == 259
            for generation in generations
            for token_id in generation.tokens.token_ids
        ]
    )
    assert all(
        [
            token_text == " "
            for generation in generations
            for token_text in generation.tokens.texts
        ]
    )
    assert generations[0].request_id == 0

View File

@@ -77,12 +77,24 @@ def serve(
    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
        "bitsandbytes-nf4",
        "bitsandbytes-fp4",
    }:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )
    server.serve(
        model_id,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        uds_path,
    )
@@ -140,12 +152,17 @@ def download_weights(
        try:
            import json

            medusa_head = hf_hub_download(
                model_id, revision=revision, filename="medusa_lm_head.pt"
            )
            if auto_convert:
                medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
                if not medusa_sf.exists():
                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
            medusa_config = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(medusa_config, "r") as f:
                config = json.load(f)
@@ -153,10 +170,17 @@ def download_weights(
            revision = "main"
            try:
                utils.weight_files(model_id, revision, extension)
                logger.info(
                    f"Files for parent {model_id} are already present on the host. "
                    "Skipping download."
                )
                return
            # Local files not found
            except (
                utils.LocalEntryNotFoundError,
                FileNotFoundError,
                utils.EntryNotFoundError,
            ):
                pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

View File

@@ -88,7 +88,6 @@ if MIXTRAL:
    __all__.append(FlashMixtral)


def get_model(
    model_id: str,
    revision: Optional[str],
@@ -157,7 +156,9 @@ def get_model(
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
@@ -249,7 +250,7 @@ def get_model(
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            use_medusa=use_medusa,
        )
    elif sharded:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -313,7 +314,9 @@ def get_model(
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        raise NotImplementedError(
            "Mixtral models requires flash attention v2, stk and megablocks"
        )

    if model_type == "opt":
        return OPTSharded(
@@ -354,7 +357,7 @@ def get_model(
        raise ValueError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise ValueError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(

View File

@@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM):
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            prefix="transformer",
        )
        if config.quantize == "gptq":
            weights._set_gptq_params(model_id)

View File

@@ -510,7 +510,11 @@ class CausalLM(Model):
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if (
            torch.cuda.is_available()
            and torch.cuda.device_count() == 1
            and quantize != "bitsandbytes"
        ):
            model = model.cuda()

        if tokenizer.pad_token_id is None:
@@ -676,7 +680,10 @@ class CausalLM(Model):
                    skip_special_tokens=False,
                )
                prefill_tokens = Tokens(
                    prefill_token_ids,
                    prefill_logprobs,
                    prefill_texts,
                    is_special=[],
                )
            else:
                prefill_tokens = None
@@ -703,11 +710,11 @@ class CausalLM(Model):
                request.id,
                prefill_tokens,
                Tokens(
                    [next_token_id_squeezed],
                    [next_token_logprob],
                    [next_token_text],
                    [next_token_id_squeezed.item() in self.all_special_ids],
                ),
                generated_text,
                top_tokens,
            )

View File

@@ -34,9 +34,10 @@ from text_generation_server.utils.layers import (
    PositionRotaryEmbedding,
    TensorParallelHead,
    get_linear,
    FastRMSNorm,
)


class LlamaConfig(PretrainedConfig):
    def __init__(
        self,
@@ -202,7 +203,7 @@ class FlashLlamaAttention(torch.nn.Module):
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        paged_attention.reshape_and_cache(
@@ -237,7 +238,7 @@ class FlashLlamaAttention(torch.nn.Module):
                input_lengths,
                max_s,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -288,7 +289,9 @@ class FlashLlamaLayer(nn.Module):
        )
        self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,

View File

@@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import (
    attention,
    HAS_FLASH_ATTN_V2_ROCM,
    HAS_FLASH_ATTN_V2_CUDA,
)
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
@@ -35,7 +39,7 @@ from text_generation_server.utils.layers import (
    PositionRotaryEmbedding,
    TensorParallelHead,
    get_linear,
    FastRMSNorm,
)
@@ -96,6 +100,7 @@ class MistralConfig(PretrainedConfig):
            **kwargs,
        )


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)

View File

@@ -29,7 +29,10 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import (
    HAS_FLASH_ATTN_V2_ROCM,
    HAS_FLASH_ATTN_V2_CUDA,
)
from text_generation_server.utils.layers import (
    FastLinear,
    FastRMSNorm,
@@ -59,28 +62,28 @@
    model_type = "mixtral"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        sliding_window=4096,
        num_experts_per_tok=2,
        num_local_experts=8,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
@@ -166,16 +169,18 @@ def _load_experts(config, prefix, mat, weights):
    rank = weights.process_group.rank()

    assert (
        config.intermediate_size % world_size == 0
    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"

    block_size = config.intermediate_size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size

    tensor = torch.empty(
        (config.num_local_experts * block_size, config.hidden_size),
        dtype=weights.dtype,
        device=weights.device,
    )

    for i in range(config.num_local_experts):
        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
@@ -184,16 +189,18 @@ def _load_experts(config, prefix, mat, weights):
            expert_slice = slice_[:, start:stop].t().contiguous()
        else:
            expert_slice = slice_[start:stop]
        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
            dtype=weights.dtype
        ).to(device=weights.device)

    return tensor


class MixtralAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.max_past = (
@@ -210,7 +217,7 @@ class MixtralAttention(torch.nn.Module):
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
@@ -219,7 +226,7 @@ class MixtralAttention(torch.nn.Module):
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)
@@ -236,17 +243,17 @@ class MixtralAttention(torch.nn.Module):
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
@@ -399,8 +406,9 @@ class BlockSparseMoE(nn.Module):
        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
        column_indices = ops.topology(
            padded_bins, self.blocking, block_rows, blocks_per_row
        )

        # For now, use meta init to save the device memory.
        data = torch.empty(
@@ -444,8 +452,7 @@ class BlockSparseMoE(nn.Module):
        # position of each bin.

        # List of size num_experts
        padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
        # padded_tokens_per_expert => [128, O, 128, ...]

        # Cumulative selected experts per token
@@ -484,8 +491,7 @@ class BlockSparseMoE(nn.Module):
        # Permute tokens and pad to prepare expert computation
        # (top_k * sequence_length + padding, model_dim)
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)

        # Create the sparse matrix topology
        with torch.no_grad():
@@ -496,8 +502,8 @@ class BlockSparseMoE(nn.Module):
        # (top_k * sequence_length + padding, ffn_dim * n_experts)
        x = stk.Matrix(
            topo.size(),
            self.act(stk.ops.sdd(x, self.w1, topo).data)
            * stk.ops.sdd(x, self.w3, topo).data,
            topo.row_indices,
            topo.column_indices,
            topo.offsets,
@@ -537,7 +543,9 @@ class MixtralLayer(nn.Module):
        self.self_attn = MixtralAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.block_sparse_moe = BlockSparseMoE(
            f"{prefix}.block_sparse_moe", config, weights
        )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -549,18 +557,18 @@ class MixtralLayer(nn.Module):
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -615,16 +623,16 @@ class MixtralModel(torch.nn.Module):
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
@@ -670,17 +678,17 @@ class FlashMixtralForCausalLM(torch.nn.Module):
            raise ValueError("max_past cannot be None")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if prefill_cache_indices is not None:
            # Slots also need to be sliced as it has the same size as the whole kv tensor
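
For context on the `_load_experts` hunk above: each shard keeps a contiguous block of every expert's intermediate dimension. A standalone sketch of the same shard arithmetic, with illustrative values that are not part of the commit:

# Mirrors block_size/start/stop from _load_experts above (values are made up).
intermediate_size, world_size = 14336, 4
assert intermediate_size % world_size == 0
block_size = intermediate_size // world_size  # rows of each expert held per shard

for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    print(rank, start, stop)  # rank 0 -> [0, 3584), rank 1 -> [3584, 7168), ...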

View File

@@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
            image = image_url_or_urls

            if image.startswith("http://") or image.startswith("https://"):
                response = requests.get(
                    image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
                )
                response.raise_for_status()
                content = response.content
            elif image.startswith("data:"):
@@ -213,7 +215,7 @@ class IdeficsImageProcessor(BaseImageProcessor):
                image = Image.open(BytesIO(content))
                # image.verify()
            except Exception:
                raise ValueError(f"Could not load image from url {image_url_or_urls}")
            return image
        else:
            raise ValueError(

View File

@@ -62,6 +62,7 @@ if IS_CUDA_SYSTEM:
elif IS_ROCM_SYSTEM:
    from vllm import layernorm_ops


@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
    image_hidden_states: Optional[torch.FloatTensor] = None
@@ -431,7 +432,9 @@ class IdeficsRMSNorm(nn.Module):
            return out
        else:
            raise ValueError(
                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )


# this was adapted from LlamaMLP
@@ -613,8 +616,13 @@ class IdeficsAttention(nn.Module):
            query_shape = query_states.shape
            key_shape = key_states.shape
            self.rotary_emb(
                query_states.view(-1, *query_shape[2:]),
                key_states.reshape(-1, *key_shape[2:]),
                cos,
                sin,
            )

            query_states = query_states.view(query_shape)
            key_states = key_states.view(key_shape)

View File

@@ -112,6 +112,7 @@ def is_url(string):
    result = urlparse(string)
    return all([result.scheme, result.netloc])


def is_image(string):
    """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
    invalidated the url"""
@@ -344,7 +345,6 @@ class IdeficsProcessor(ProcessorMixin):
        image_objects = self.image_processor(image_objects, transform=transform)

        text_encoding = self.tokenizer(
            text=full_text,
            add_special_tokens=False,

View File

@@ -11,7 +11,7 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
    Batch,
@@ -165,8 +165,6 @@ class FlashCausalLMBatch(Batch):
            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)
@@ -229,7 +227,9 @@ class FlashCausalLMBatch(Batch):
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
@@ -424,7 +424,9 @@ class FlashCausalLMBatch(Batch):
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -480,7 +482,9 @@ class FlashCausalLMBatch(Batch):
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
@@ -586,7 +590,11 @@ class FlashCausalLMBatch(Batch):
            device=batches[0].next_token_chooser.device,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
@@ -622,7 +630,7 @@ class FlashCausalLMBatch(Batch):
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
@@ -727,43 +735,54 @@ class FlashCausalLM(Model):
    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length
            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        return self.model.forward(
            input_ids=input_ids,
@@ -808,20 +827,31 @@ class FlashCausalLM(Model):
        else:
            speculative_logits = None

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            get_speculate(),
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
@@ -851,11 +881,7 @@ class FlashCausalLM(Model):
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
@@ -863,11 +889,7 @@ class FlashCausalLM(Model):
        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length
@@ -901,7 +923,6 @@ class FlashCausalLM(Model):
            cumulative_length += input_length

        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
@@ -983,8 +1004,10 @@ class FlashCausalLM(Model):
                    current_stopped = False
                stopped = stopped and current_stopped

                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
                _next_token_logprobs = next_token_logprobs[
                    index : index + n_accepted_ids - left
                ]
                index += n_accepted_ids

                # Shard generations
@@ -1027,7 +1050,10 @@ class FlashCausalLM(Model):
                    )
                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None
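
For context on the speculative-decoding branch reformatted above: every flat batch tensor is expanded so that each request contributes its last token plus its speculative tokens, and positions advance accordingly. A standalone sketch of the same expansion with illustrative shapes (not part of the commit):

import torch

B, speculative_length = 2, 2          # two requests, two speculative tokens each
new_length = speculative_length + 1

input_ids = torch.tensor([10, 20])                    # last token of each request
speculative_ids = torch.tensor([[11, 12], [21, 22]])  # shape (B, speculative_length)
position_ids = torch.tensor([5, 7])                   # next position of each request

# Interleave each request's token with its speculative tokens, then flatten.
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
# tensor([10, 11, 12, 20, 21, 22])

# Positions advance by 0..speculative_length for every request.
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
# tensor([5, 6, 7, 7, 8, 9])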

View File

@@ -71,12 +71,19 @@ class FlashLlama(FlashCausalLM):
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json

            medusa_config = hf_hub_download(
                use_medusa, revision=revision, filename="config.json"
            )
            with open(medusa_config, "r") as f:
                config = json.load(f)
            medusa_head = hf_hub_download(
                use_medusa, revision=revision, filename="medusa_lm_head.pt"
            )
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            lm_head = model.lm_head
            model.lm_head = MedusaModel(config, weights, lm_head)

View File

@@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        global SLIDING_WINDOW
        global SLIDING_WINDOW_BLOCKS
@@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)
@@ -184,7 +184,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
@@ -273,20 +275,20 @@ class FlashMistralBatch(FlashCausalLMBatch):
            blocks=blocks,
            max_blocks=max_blocks,
            prefill_cache_indices=prefill_cache_indices,
            speculative_ids=None,
        )


class BaseFlashMistral(FlashCausalLM):
    def __init__(
        self,
        config_cls,
        model_cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        global SLIDING_WINDOW
        global SLIDING_WINDOW_BLOCKS
@@ -345,43 +347,54 @@ class BaseFlashMistral(FlashCausalLM):
    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length
            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        logits = self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
@@ -401,12 +414,12 @@ class BaseFlashMistral(FlashCausalLM):
class FlashMistral(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(FlashMistral, self).__init__(
            config_cls=MistralConfig,
@@ -415,5 +428,5 @@ class FlashMistral(BaseFlashMistral):
            revision=revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

View File

@@ -3,17 +3,20 @@ import torch
from typing import Optional

from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
    MixtralConfig,
    FlashMixtralForCausalLM,
)


class FlashMixtral(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(FlashMixtral, self).__init__(
            config_cls=MixtralConfig,
@@ -22,5 +25,5 @@ class FlashMixtral(BaseFlashMistral):
            revision=revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

View File

@@ -792,7 +792,10 @@ class IdeficsCausalLM(Model):
                    skip_special_tokens=False,
                )
                prefill_tokens = Tokens(
                    prefill_token_ids,
                    prefill_logprobs,
                    prefill_texts,
                    is_special=[],
                )
            else:
                prefill_tokens = None
@@ -803,10 +806,10 @@ class IdeficsCausalLM(Model):
                request.id,
                prefill_tokens,
                Tokens(
                    [next_token_id_squeezed],
                    [next_token_logprob],
                    [next_token_text],
                    [next_token_id_squeezed.item() in self.all_special_ids],
                ),
                generated_text,
                top_tokens,

View File

@@ -56,7 +56,7 @@ class Model(ABC):
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
        )

    @property

View File

@@ -736,7 +736,7 @@ class Seq2SeqLM(Model):
                    [self.tokenizer.bos_token_id],
                    [float("nan")],
                    [self.tokenizer.bos_token],
                    [False],
                )
            else:
                prefill_tokens = None
@@ -763,10 +763,10 @@ class Seq2SeqLM(Model):
                request.id,
                prefill_tokens,
                Tokens(
                    [next_token_id_squeezed],
                    [next_token_logprob],
                    [next_token_text],
                    [next_token_id_squeezed.item() in self.all_special_ids],
                ),
                generated_text,
                top_tokens,

View File

@@ -66,7 +66,10 @@ class Tokens:
    def to_pb(self) -> generate_pb2.Tokens:
        return generate_pb2.Tokens(
            ids=self.token_ids,
            logprobs=self.logprobs,
            texts=self.texts,
            is_special=self.is_special,
        )

    def __len__(self):

View File

@@ -159,7 +159,13 @@ def serve(
        try:
            model = get_model(
                model_id,
                revision,
                sharded,
                quantize,
                speculate,
                dtype,
                trust_remote_code,
            )
        except Exception:
            logger.exception("Error when initializing model")
@@ -207,5 +213,7 @@ def serve(
        await server.stop(0)

    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
        )
    )

View File

@@ -51,7 +51,9 @@ except ImportError as e:
        ) from e
    elif IS_ROCM_SYSTEM:
        for idx in range(torch.cuda.device_count()):
            if "MI210" not in torch.cuda.get_device_name(
                idx
            ) and "MI250" not in torch.cuda.get_device_name(idx):
                raise ImportError(
                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                )
@@ -91,8 +93,10 @@ def attention(
        )
    elif HAS_FLASH_ATTN_V2_ROCM:
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # RoCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,

View File

@ -11,40 +11,44 @@ logger = getLogger(__name__)
try: try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError: except ImportError:
logger.error('exllamav2_kernels not installed.') logger.error("exllamav2_kernels not installed.")
raise raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta") none_tensor = torch.empty((1, 1), device="meta")
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4""" """Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,) output_shape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device) output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
gemm_half_q_half(x, q_handle, output, force_cuda) gemm_half_q_half(x, q_handle, output, force_cuda)
return output.view(output_shape) return output.view(output_shape)
def ext_make_q_matrix(w: dict, temp_dq, key: str = None): def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
""" """
Create Q matrix Create Q matrix
""" """
# EXL2 # EXL2
    # won't work at the moment because the tensors are not the same. # won't work at the moment because the tensors are not the same.
if "q_weight" in w: if "q_weight" in w:
w["q_scale_max"] /= 256 w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short() w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short() w["q_invperm"] = w["q_invperm"].short()
return make_q_matrix(w["q_weight"], return make_q_matrix(
w["q_perm"], w["q_weight"],
w["q_invperm"], w["q_perm"],
w["q_scale"], w["q_invperm"],
w["q_scale_max"], w["q_scale"],
w["q_groups"], w["q_scale_max"],
none_tensor, w["q_groups"],
none_tensor, none_tensor,
none_tensor, none_tensor,
temp_dq) none_tensor,
temp_dq,
)
# GPTQ # GPTQ
elif "qweight" in w: elif "qweight" in w:
if w["scales"].dtype == torch.float: if w["scales"].dtype == torch.float:
@ -52,31 +56,40 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
# GPTQ with g_idx (act_order) # GPTQ with g_idx (act_order)
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device) w["q_perm"] = torch.empty(
(w["qweight"].shape[0] * 8,),
dtype=torch.short,
device=w["qweight"].device,
)
w["q_invperm"] = torch.empty_like(w["q_perm"]) w["q_invperm"] = torch.empty_like(w["q_perm"])
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(w["qweight"], return make_q_matrix(
w["q_perm"], w["qweight"],
w["q_invperm"], w["q_perm"],
none_tensor, w["q_invperm"],
none_tensor, none_tensor,
none_tensor, none_tensor,
w["qzeros"], none_tensor,
w["scales"], w["qzeros"],
w["g_idx"].cpu(), w["scales"],
temp_dq) w["g_idx"].cpu(),
temp_dq,
)
# GPTQ without g_idx # GPTQ without g_idx
else: else:
return make_q_matrix(w["qweight"], return make_q_matrix(
none_tensor, w["qweight"],
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
w["qzeros"], none_tensor,
w["scales"], w["qzeros"],
none_tensor, w["scales"],
temp_dq) none_tensor,
temp_dq,
)
DEVICE = None DEVICE = None
FIXED_BYTES = 0 FIXED_BYTES = 0
@ -106,14 +119,15 @@ class QuantLinear(nn.Module):
super().__init__() super().__init__()
if bits != 4: if bits != 4:
raise ValueError( raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.") f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
)
self.q_handle = None self.q_handle = None
self.q_tensors = None self.q_tensors = None
self.bits = bits self.bits = bits
self.maxq = 2 ** self.bits - 1 self.maxq = 2**self.bits - 1
self.infeatures = qweight.shape[0] // self.bits * 32 self.infeatures = qweight.shape[0] // self.bits * 32
self.outfeatures = qweight.shape[1] self.outfeatures = qweight.shape[1]
self.padding = - self.outfeatures % 32 self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding self.outfeatures = self.outfeatures + self.padding
self.device = qweight.device self.device = qweight.device
@ -128,9 +142,12 @@ class QuantLinear(nn.Module):
outfeatures = self.outfeatures outfeatures = self.outfeatures
assert qweight.shape == (infeatures // 32 * self.bits, outfeatures) assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
assert infeatures % self.group_size == 0 assert infeatures % self.group_size == 0
assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits) assert qzeros.shape == (
infeatures // self.group_size,
outfeatures // 32 * self.bits,
)
assert scales.shape == (infeatures // self.group_size, outfeatures) assert scales.shape == (infeatures // self.group_size, outfeatures)
assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}" assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
global FIXED_BYTES, LAYERS global FIXED_BYTES, LAYERS
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed()) FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
@ -140,33 +157,31 @@ class QuantLinear(nn.Module):
assert self.qweight.device.type == "cuda" assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None assert self.qweight.device.index is not None
self.q_tensors = { self.q_tensors = {
"qweight":self.qweight, "qweight": self.qweight,
"qzeros":self.qzeros, "qzeros": self.qzeros,
"scales":self.scales, "scales": self.scales,
"g_idx":self.g_idx "g_idx": self.g_idx,
} }
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
self.q_handle = ext_make_q_matrix( self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
self.q_tensors, temp_dq
) def forward(self, x, force_cuda=False):
def forward(self, x, force_cuda = False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
if self.bias is not None: if self.bias is not None:
output.add_(self.bias) output.add_(self.bias)
return output return output
def temp_dq_size(self): def temp_dq_size(self):
return self.infeatures * self.outfeatures * 2 + 128 return self.infeatures * self.outfeatures * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size): def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128 return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16): def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
class ExLlamaV2DeviceTensors: class ExLlamaV2DeviceTensors:
device_idx: int device_idx: int
@ -177,13 +192,16 @@ class ExLlamaV2DeviceTensors:
def __init__(self, device, scratch_bytes): def __init__(self, device, scratch_bytes):
self.device = device self.device = device
self.scratch_bytes = scratch_bytes self.scratch_bytes = scratch_bytes
def prepare(self): def prepare(self):
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device) self.scratch = torch.empty(
(self.scratch_bytes // 2,), dtype=torch.half, device=self.device
)
def get_scratch_slice(self, size_bytes): def get_scratch_slice(self, size_bytes):
if self.scratch is None: self.prepare() if self.scratch is None:
self.prepare()
size_bytes = ((size_bytes + 127) // 128) * 128 size_bytes = ((size_bytes + 127) // 128) * 128
size_half = size_bytes // 2 size_half = size_bytes // 2
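The scratch-buffer sizing that this hunk reflows can be read in isolation. The helper names below are hypothetical, but the arithmetic mirrors `temp_dq_size`, `temp_fwd_size`, and the 128-byte rounding in `get_scratch_slice`; the sizes are examples only.

```python
# Sketch of the exllamav2 scratch-space arithmetic.
def temp_dq_size(infeatures: int, outfeatures: int) -> int:
    # fp16 dequantized-weight buffer (2 bytes per element) plus a small pad
    return infeatures * outfeatures * 2 + 128

def temp_fwd_size(outfeatures: int, max_input_len: int = 4096, max_batch_size: int = 16) -> int:
    return outfeatures * max_input_len * max_batch_size * 4 + 128

def align_scratch(size_bytes: int) -> int:
    # get_scratch_slice rounds every request up to a 128-byte boundary
    return ((size_bytes + 127) // 128) * 128

fixed = temp_dq_size(4096, 4096) + temp_fwd_size(4096)
print(align_scratch(fixed))  # total scratch bytes reserved for one such layer
```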
View File
@ -35,7 +35,9 @@ HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding") logger.warning(
"Disabling exllama v2 and using v1 instead because there are issues when sharding"
)
V2 = False V2 = False
if os.getenv("DISABLE_EXLLAMA") == "True": if os.getenv("DISABLE_EXLLAMA") == "True":
@ -43,17 +45,19 @@ if os.getenv("DISABLE_EXLLAMA") == "True":
elif CAN_EXLLAMA: elif CAN_EXLLAMA:
try: try:
if V2: if V2:
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear, from text_generation_server.utils.gptq.exllamav2 import (
create_exllama_buffers, QuantLinear as ExllamaQuantLinear,
set_device, create_exllama_buffers,
) set_device,
)
HAS_EXLLAMA = "2" HAS_EXLLAMA = "2"
else: else:
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear, from text_generation_server.utils.gptq.exllama import (
create_exllama_buffers, Ex4bitLinear as ExllamaQuantLinear,
set_device, create_exllama_buffers,
) set_device,
)
HAS_EXLLAMA = "1" HAS_EXLLAMA = "1"
@ -114,7 +118,7 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st
@classmethod @classmethod
def load_conv2d_no_bias( def load_conv2d_no_bias(
cls, prefix, weights, in_channels, out_channels, kernel_size, stride cls, prefix, weights, in_channels, out_channels, kernel_size, stride
): ):
weight = weights.get_tensor(f"{prefix}.weight") weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights(): with init_empty_weights():
@ -138,9 +142,9 @@ torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
class FastLinear(nn.Module): class FastLinear(nn.Module):
def __init__( def __init__(
self, self,
weight, weight,
bias, bias,
) -> None: ) -> None:
super().__init__() super().__init__()
self.weight = nn.Parameter(weight) self.weight = nn.Parameter(weight)
@ -164,9 +168,9 @@ class FastLinear(nn.Module):
class EETQLinear(nn.Module): class EETQLinear(nn.Module):
def __init__( def __init__(
self, self,
weight, weight,
bias, bias,
) -> None: ) -> None:
super().__init__() super().__init__()
device = weight.device device = weight.device
@ -185,13 +189,13 @@ class EETQLinear(nn.Module):
class Linear8bitLt(nn.Module): class Linear8bitLt(nn.Module):
def __init__( def __init__(
self, self,
weight, weight,
bias, bias,
has_fp16_weights=True, has_fp16_weights=True,
memory_efficient_backward=False, memory_efficient_backward=False,
threshold=0.0, threshold=0.0,
index=None, index=None,
): ):
super().__init__() super().__init__()
assert ( assert (
@ -325,7 +329,9 @@ def get_linear(weight, bias, quantize):
) )
if use_exllama: if use_exllama:
linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) linear = ExllamaQuantLinear(
qweight, qzeros, scales, g_idx, bias, bits, groupsize
)
else: else:
linear = QuantLinear( linear = QuantLinear(
qweight, qweight,
@ -533,7 +539,6 @@ try:
else: else:
dropout_layer_norm = None dropout_layer_norm = None
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
@ -569,7 +574,6 @@ try:
return normed_hidden_states, residual return normed_hidden_states, residual
class FastRMSNorm(nn.Module): class FastRMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float): def __init__(self, weight: torch.Tensor, eps: float):
super().__init__() super().__init__()
@ -601,7 +605,11 @@ try:
return self.weight * hidden_states, residual return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM: elif IS_CUDA_SYSTEM:
# faster post attention rms norm # faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( (
normed_hidden_states,
res,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states, hidden_states,
residual, residual,
self.weight, self.weight,
@ -638,7 +646,8 @@ try:
return out, residual return out, residual
else: else:
raise ValueError( raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
except ImportError: except ImportError:
pass pass
@ -650,14 +659,12 @@ try:
elif IS_ROCM_SYSTEM: elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops from vllm import pos_encoding_ops
def _create_inv_freq(dim, base, device): def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / ( inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
) )
return inv_freq return inv_freq
def _get_rope_config(config): def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None: if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = { rope_scaling = {
@ -667,7 +674,6 @@ try:
return rope_scaling return rope_scaling
return getattr(config, "rope_scaling", None) return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module): class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor): def __init__(self, inv_freq, scaling_factor):
super().__init__() super().__init__()
@ -680,17 +686,23 @@ try:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
self.dynamic_args = None self.dynamic_args = None
def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
# Such controlflows may add some overhead. # Such controlflows may add some overhead.
if IS_CUDA_SYSTEM: if IS_CUDA_SYSTEM:
rotary_dim = cos.shape[-1] rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim] q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim: 2 * rotary_dim] q2 = query[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim] k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim: 2 * rotary_dim] k2 = key[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM: elif IS_ROCM_SYSTEM:
@ -700,17 +712,11 @@ try:
head_size = query.shape[-1] head_size = query.shape[-1]
# Inplace operation, updating query and key. # Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding( pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
query,
key,
head_size,
cos,
sin,
True
)
else: else:
raise ValueError( raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
@classmethod @classmethod
def static(cls, config, dim, base, device): def static(cls, config, dim, base, device):
@ -732,15 +738,16 @@ try:
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"], max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0, base=10000.0,
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
extrapolation_factor=1, extrapolation_factor=1,
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1 beta_slow=1,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -773,15 +780,16 @@ try:
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling["original_max_position_embeddings"], max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0, base=10000.0,
device=inv_freq.device, device=inv_freq.device,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
extrapolation_factor=1, extrapolation_factor=1,
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1 beta_slow=1,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -793,9 +801,9 @@ try:
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance) # or if we're on a new device (possibly due to tracing for instance)
if ( if (
seqlen > self._seq_len_cached seqlen > self._seq_len_cached
or self._cos_cached.device != device or self._cos_cached.device != device
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
@ -809,7 +817,7 @@ try:
self._sin_cached = torch.sin(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin( def get_cos_sin(
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
): ):
""" """
Return cos and sin for the asked position ids Return cos and sin for the asked position ids
@ -827,7 +835,6 @@ try:
        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet another controlflow. # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet another controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1) return cos.unsqueeze(1), sin.unsqueeze(1)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
@ -840,14 +847,14 @@ try:
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance) # or if we're on a new device (possibly due to tracing for instance)
if ( if (
seqlen > self._seq_len_cached seqlen > self._seq_len_cached
or self._cos_cached.device != device or self._cos_cached.device != device
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
if seqlen > self.max_position_embeddings: if seqlen > self.max_position_embeddings:
newbase = self.base * ( newbase = self.base * (
(self.scaling_factor * seqlen / self.max_position_embeddings) (self.scaling_factor * seqlen / self.max_position_embeddings)
- (self.scaling_factor - 1) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2)) ) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq( self.inv_freq = _create_inv_freq(
self.dim, newbase, self.inv_freq.device self.dim, newbase, self.inv_freq.device
@ -861,24 +868,28 @@ try:
self._cos_cached = torch.cos(freqs).to(dtype) self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype)
# Inverse dim formula to find dim based on number of rotations # Inverse dim formula to find dim based on number of rotations
import math import math
def find_correction_dim(
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): num_rotations, dim, base=10000, max_position_embeddings=2048
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) ):
return (
dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
) / (2 * math.log(base))
# Find dim range bounds based on rotations # Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): def find_correction_range(
low = math.floor(find_correction_dim( low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
low_rot, dim, base, max_position_embeddings)) ):
high = math.ceil(find_correction_dim( low = math.floor(
high_rot, dim, base, max_position_embeddings)) find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim): def linear_ramp_mask(min, max, dim):
if min == max: if min == max:
max += 0.001 # Prevent singularity max += 0.001 # Prevent singularity
@ -887,16 +898,25 @@ try:
ramp_func = torch.clamp(linear_func, 0, 1) ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func return ramp_func
def get_mscale(scale=1): def get_mscale(scale=1):
if scale <= 1: if scale <= 1:
return 1.0 return 1.0
return 0.1 * math.log(scale) + 1.0 return 0.1 * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor, def __init__(
attn_factor, beta_fast, beta_slow): self,
dim,
max_position_embeddings,
base,
device,
scaling_factor,
*,
extrapolation_factor,
attn_factor,
beta_fast,
beta_slow,
):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor) super().__init__(inv_freq, scaling_factor)
self.dim = dim self.dim = dim
@ -906,16 +926,17 @@ try:
self.attn_factor = attn_factor self.attn_factor = attn_factor
self.beta_fast = beta_fast self.beta_fast = beta_fast
self.beta_slow = beta_slow self.beta_slow = beta_slow
self.mscale = float(get_mscale( self.mscale = float(
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance) # or if we're on a new device (possibly due to tracing for instance)
if ( if (
seqlen > self._seq_len_cached seqlen > self._seq_len_cached
or self._cos_cached.device != device or self._cos_cached.device != device
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
if seqlen > self.max_position_embeddings: if seqlen > self.max_position_embeddings:
inv_freq_extrapolation = _create_inv_freq( inv_freq_extrapolation = _create_inv_freq(
@ -923,15 +944,26 @@ try:
) )
freqs = 1.0 / inv_freq_extrapolation freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, low, high = find_correction_range(
self.max_position_embeddings) self.beta_fast,
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to( self.beta_slow,
device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation self.dim,
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask self.base,
self.max_position_embeddings,
)
inv_freq_mask = (
1
- linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
self.inv_freq = inv_freq self.inv_freq = inv_freq
self.mscale = float(get_mscale( self.mscale = float(
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
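Because the YaRN helpers are re-wrapped quite aggressively in this hunk, the math is easier to check in a standalone form. The sketch below restates the same formulas; the body of `linear_ramp_mask` is only partly visible above, so the linear ramp written here is the standard YaRN construction and should be read as an assumption, and the dimensions and scaling factor are arbitrary examples.

```python
# Self-contained restatement of the YaRN frequency blending used by
# YarnPositionRotaryEmbedding (example values only).
import math
import torch

def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )

def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # clamp just in case

def linear_ramp_mask(min_, max_, dim):
    # Assumed standard YaRN ramp; the original body is partly elided in the hunk.
    if min_ == max_:
        max_ += 0.001  # prevent singularity
    linear_func = (torch.arange(dim, dtype=torch.float32) - min_) / (max_ - min_)
    return torch.clamp(linear_func, 0, 1)

def get_mscale(scale=1):
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0

dim, base, scaling_factor = 128, 10000.0, 4.0
inv_freq_extrapolation = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
inv_freq_interpolation = inv_freq_extrapolation / scaling_factor
low, high = find_correction_range(32, 1, dim, base, 2048)  # beta_fast=32, beta_slow=1
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2)  # extrapolation_factor = 1
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
print(get_mscale(scaling_factor))  # ~1.1386, the attention magnitude correction
```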
View File
@ -2,6 +2,7 @@ import torch
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass @dataclass
class Output: class Output:
logits: torch.FloatTensor = None logits: torch.FloatTensor = None
@ -11,7 +12,9 @@ class Output:
class ResBlock(torch.nn.Module): class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True) self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU() self.act = torch.nn.SiLU()
def forward(self, x): def forward(self, x):
@ -19,15 +22,13 @@ class ResBlock(torch.nn.Module):
class MedusaModel(torch.nn.Module): class MedusaModel(torch.nn.Module):
def __init__( def __init__(self, config, weights, lm_head):
self,
config,
weights,
lm_head
):
super().__init__() super().__init__()
self.heads = torch.nn.ModuleList( self.heads = torch.nn.ModuleList(
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])] [
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
) )
self.lm_head = lm_head self.lm_head = lm_head
@ -40,9 +41,16 @@ class MedusaModel(torch.nn.Module):
class MedusaHead(torch.nn.Module): class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])]) self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks) n = len(self.blocks)
self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False) self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x): def forward(self, x):
for block in self.blocks: for block in self.blocks:
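To make the re-indented Medusa classes easier to follow, here is a stripped-down sketch of the same structure with plain `torch.nn.Linear` standing in for `FastLinear`. The residual body of `ResBlock.forward` is not visible in this hunk, so the `x + act(linear(x))` form below is an assumption, and all sizes are placeholders.

```python
# Hedged sketch of one Medusa head: a stack of residual blocks followed by an
# untied vocabulary projection.
import torch

class ResBlockSketch(torch.nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        # Assumed residual connection around the activated projection.
        return x + self.act(self.linear(x))

class MedusaHeadSketch(torch.nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, medusa_num_layers: int = 1):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [ResBlockSketch(hidden_size) for _ in range(medusa_num_layers)]
        )
        self.out = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.out(x)

head = MedusaHeadSketch(hidden_size=16, vocab_size=100)
print(head(torch.randn(2, 16)).shape)  # torch.Size([2, 100])
```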
View File
@ -7,23 +7,26 @@ from vllm import attention_ops
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, def reshape_and_cache(
slots: torch.Tensor): key: torch.Tensor,
cache_ops.reshape_and_cache( value: torch.Tensor,
key, value, key_cache, value_cache, slots key_cache: torch.Tensor,
) value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
def attention( def attention(
out: torch.Tensor, out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor, kv_head_mapping: torch.Tensor,
softmax_scale: float, softmax_scale: float,
block_tables: torch.Tensor, block_tables: torch.Tensor,
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
): ):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights # Copyright 2023 The vLLM team. All rights
@ -45,9 +48,7 @@ def attention(
# value_cache => [num_blocks, num_heads, head_size, block_size] # value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
max_num_partitions = ( max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
(max_s + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
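The collapsed `max_num_partitions` expression in this hunk is just a ceiling division of the maximum sequence length by the partition size; a quick illustration with arbitrary lengths:

```python
# Ceiling division used to pick between paged attention V1 and V2.
_PARTITION_SIZE = 512

def max_num_partitions(max_s: int) -> int:
    return (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

print(max_num_partitions(512))   # 1 -> V1, no cross-partition reduction needed
print(max_num_partitions(2049))  # 5 -> V2, results reduced across partitions
```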
View File
@ -38,7 +38,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
os.makedirs(model_id, exist_ok=True) os.makedirs(model_id, exist_ok=True)
cache_dir = model_id cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}") logger.info(f"Saving the newly created merged model to {cache_dir}")
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code) tokenizer = AutoTokenizer.from_pretrained(
base_model_id, trust_remote_code=trust_remote_code
)
model.save_pretrained(cache_dir, safe_serialization=True) model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir) model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir) tokenizer.save_pretrained(cache_dir)
View File
@ -1,12 +1,11 @@
SPECULATE = None SPECULATE = None
def get_speculate() -> int: def get_speculate() -> int:
global SPECULATE global SPECULATE
return SPECULATE return SPECULATE
def set_speculate(speculate: int): def set_speculate(speculate: int):
global SPECULATE global SPECULATE
SPECULATE = speculate SPECULATE = speculate
View File
@ -16,6 +16,7 @@ from text_generation_server.utils.logits_process import (
from text_generation_server.utils.watermark import WatermarkLogitsProcessor from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser: class NextTokenChooser:
def __init__( def __init__(
self, self,
@ -145,21 +146,31 @@ class StoppingCriteria:
pb.ignore_eos_token, pb.ignore_eos_token,
) )
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
def create_n_gram_speculation(
input_ids: torch.Tensor,
next_ids: torch.Tensor,
accepted_ids: torch.Tensor,
speculate: int,
verbose: bool,
):
# Very trivial approach, find first match in the string. # Very trivial approach, find first match in the string.
# This is much less refined than actual n-gram but seems to work # This is much less refined than actual n-gram but seems to work
# relatively OK in grounded mode and is by far much faster with # relatively OK in grounded mode and is by far much faster with
# much less worst case complexity as everything happens on device. # much less worst case complexity as everything happens on device.
B = accepted_ids.shape[0] B = accepted_ids.shape[0]
device = input_ids.device device = input_ids.device
seeds = next_ids[accepted_ids.cumsum(dim=-1) -1 ] seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1 indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device) all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
speculate, device=device
)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1) all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices) speculative_ids = input_ids.gather(dim=-1, index=all_indices)
return speculative_ids return speculative_ids
class HeterogeneousNextTokenChooser: class HeterogeneousNextTokenChooser:
def __init__( def __init__(
self, self,
@ -228,7 +239,15 @@ class HeterogeneousNextTokenChooser:
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False): def __call__(
self,
input_ids: torch.Tensor,
scores: torch.Tensor,
speculate: int,
speculated_ids: Optional[torch.Tensor] = None,
speculative_scores: Optional[torch.Tensor] = None,
verbose=False,
):
if speculated_ids is not None: if speculated_ids is not None:
B = scores.shape[0] // (speculated_ids.shape[1] + 1) B = scores.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1 S = speculated_ids.shape[1] + 1
@ -249,12 +268,11 @@ class HeterogeneousNextTokenChooser:
for warper in self.warpers: for warper in self.warpers:
_scores = warper(input_ids, _scores) _scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores) _next_ids = self.choice(_scores)
scores[:, j] = _scores scores[:, j] = _scores
next_ids[:, j] = _next_ids next_ids[:, j] = _next_ids
next_ids = next_ids.view(B*S) next_ids = next_ids.view(B * S)
scores = scores.view( B* S, -1) scores = scores.view(B * S, -1)
if speculated_ids is not None: if speculated_ids is not None:
accepted_ids = [] accepted_ids = []
@ -262,7 +280,7 @@ class HeterogeneousNextTokenChooser:
S = speculated_ids.shape[1] + 1 S = speculated_ids.shape[1] + 1
indices = [] indices = []
for i in range(B): for i in range(B):
_next_ids = next_ids[i*S: (i + 1)*S] _next_ids = next_ids[i * S : (i + 1) * S]
_speculated_ids = speculated_ids[i] _speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S index = i * S
@ -278,7 +296,9 @@ class HeterogeneousNextTokenChooser:
break break
accepted_ids.append(accepted) accepted_ids.append(accepted)
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype) accepted_ids = torch.tensor(
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
)
next_ids = next_ids[indices] next_ids = next_ids[indices]
scores = scores[indices] scores = scores[indices]
indices = torch.arange(B, device=input_ids.device) * S indices = torch.arange(B, device=input_ids.device) * S
@ -296,7 +316,9 @@ class HeterogeneousNextTokenChooser:
speculative_ids = Greedy()(speculative_scores) speculative_ids = Greedy()(speculative_scores)
else: else:
# n-gram # n-gram
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose) speculative_ids = create_n_gram_speculation(
input_ids, next_ids, accepted_ids, speculate, verbose
)
else: else:
speculative_ids = None speculative_ids = None
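Since `create_n_gram_speculation` is the densest change in this file, here it is lifted out with a toy batch (the unused `verbose` flag is dropped, and the example tensors are invented purely for illustration):

```python
# N-gram speculation: seed with the last accepted token, find its first
# occurrence in the prompt, and propose the tokens that followed it.
import torch

def create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate):
    B = accepted_ids.shape[0]
    device = input_ids.device
    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
        speculate, device=device
    )
    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
    return input_ids.gather(dim=-1, index=all_indices)

input_ids = torch.tensor([[5, 7, 9, 7, 11]])
next_ids = torch.tensor([7])      # last accepted token for the single sequence
accepted_ids = torch.tensor([1])  # one token accepted
print(create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate=2))
# tensor([[9, 7]]) -> the two tokens that followed the first occurrence of 7
```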
View File
@ -16,7 +16,7 @@ class Weights:
dtype, dtype,
process_group, process_group,
aliases: Optional[Dict[str, List[str]]] = None, aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None prefix: Optional[str] = None,
): ):
routing = {} routing = {}
for filename in filenames: for filename in filenames:
@ -213,7 +213,8 @@ class Weights:
bits, groupsize = self._get_gptq_params() bits, groupsize = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq"
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -283,7 +284,7 @@ class Weights:
if use_exllama: if use_exllama:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0) scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim= 0) g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
g_idx = g_idx - g_idx[0] g_idx = g_idx - g_idx[0]
else: else:
# The triton kernel reorders the scales/zero points instead of the weight/activation. # The triton kernel reorders the scales/zero points instead of the weight/activation.
View File
@ -21,14 +21,14 @@ def main():
block = [] block = []
for line in lines: for line in lines:
if line.startswith(" -") or line.startswith(" -"): if line.startswith(" -") or line.startswith(" -"):
rendered_block = '\n'.join(block) rendered_block = "\n".join(block)
if header: if header:
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
else: else:
final_doc += f"```shell\n{rendered_block}\n```\n" final_doc += f"```shell\n{rendered_block}\n```\n"
block = [] block = []
tokens = line.split("<") tokens = line.split("<")
if len(tokens)>1: if len(tokens) > 1:
header = tokens[-1][:-1] header = tokens[-1][:-1]
else: else:
header = line.split("--")[-1] header = line.split("--")[-1]
@ -36,7 +36,7 @@ def main():
block.append(line) block.append(line)
rendered_block = '\n'.join(block) rendered_block = "\n".join(block)
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
block = [] block = []