chore: formatting
This commit is contained in:
parent
3a521c92b3
commit
72ee382ded
|
@ -25,6 +25,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
|
||||||
|
|
||||||
class ResponseComparator(JSONSnapshotExtension):
|
class ResponseComparator(JSONSnapshotExtension):
|
||||||
rtol = 0.2
|
rtol = 0.2
|
||||||
|
|
||||||
def serialize(
|
def serialize(
|
||||||
self,
|
self,
|
||||||
data,
|
data,
|
||||||
|
@ -69,7 +70,9 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||||
prefill_token.id == other.id
|
prefill_token.id == other.id
|
||||||
and prefill_token.text == other.text
|
and prefill_token.text == other.text
|
||||||
and (
|
and (
|
||||||
math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol)
|
math.isclose(
|
||||||
|
prefill_token.logprob, other.logprob, rel_tol=self.rtol
|
||||||
|
)
|
||||||
if prefill_token.logprob is not None
|
if prefill_token.logprob is not None
|
||||||
else prefill_token.logprob == other.logprob
|
else prefill_token.logprob == other.logprob
|
||||||
)
|
)
|
||||||
|
@ -153,6 +156,7 @@ class GenerousResponseComparator(ResponseComparator):
|
||||||
# Needed for GPTQ with exllama which has serious numerical fluctuations.
|
# Needed for GPTQ with exllama which has serious numerical fluctuations.
|
||||||
rtol = 0.75
|
rtol = 0.75
|
||||||
|
|
||||||
|
|
||||||
class LauncherHandle:
|
class LauncherHandle:
|
||||||
def __init__(self, port: int):
|
def __init__(self, port: int):
|
||||||
self.client = AsyncClient(f"http://localhost:{port}")
|
self.client = AsyncClient(f"http://localhost:{port}")
|
||||||
|
@ -198,6 +202,7 @@ class ProcessLauncherHandle(LauncherHandle):
|
||||||
def response_snapshot(snapshot):
|
def response_snapshot(snapshot):
|
||||||
return snapshot.use_extension(ResponseComparator)
|
return snapshot.use_extension(ResponseComparator)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def generous_response_snapshot(snapshot):
|
def generous_response_snapshot(snapshot):
|
||||||
return snapshot.use_extension(GenerousResponseComparator)
|
return snapshot.use_extension(GenerousResponseComparator)
|
||||||
|
@ -219,7 +224,7 @@ def launcher(event_loop):
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
use_flash_attention: bool = True,
|
use_flash_attention: bool = True,
|
||||||
dtype: Optional[str] = None
|
dtype: Optional[str] = None,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
master_port = random.randint(10_000, 20_000)
|
master_port = random.randint(10_000, 20_000)
|
||||||
|
@ -282,7 +287,7 @@ def launcher(event_loop):
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
use_flash_attention: bool = True,
|
use_flash_attention: bool = True,
|
||||||
dtype: Optional[str] = None
|
dtype: Optional[str] = None,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
|
|
||||||
|
@ -335,7 +340,7 @@ def launcher(event_loop):
|
||||||
],
|
],
|
||||||
volumes=volumes,
|
volumes=volumes,
|
||||||
ports={"80/tcp": port},
|
ports={"80/tcp": port},
|
||||||
shm_size="1G"
|
shm_size="1G",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield ContainerLauncherHandle(client, container.name, port)
|
yield ContainerLauncherHandle(client, container.name, port)
|
||||||
|
|
|
@ -50,10 +50,16 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
|
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
|
||||||
responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)
|
responses = await generate_load(
|
||||||
|
flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
|
||||||
|
)
|
||||||
|
|
||||||
assert len(responses) == 4
|
assert len(responses) == 4
|
||||||
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
|
assert all(
|
||||||
assert responses[0].generated_text == '\nDeep learning is a subset of machine learning'
|
[r.generated_text == responses[0].generated_text for r in responses]
|
||||||
|
), f"{[r.generated_text for r in responses]}"
|
||||||
|
assert (
|
||||||
|
responses[0].generated_text == "\nDeep learning is a subset of machine learning"
|
||||||
|
)
|
||||||
|
|
||||||
assert responses == response_snapshot
|
assert responses == response_snapshot
|
||||||
|
|
|
@ -56,7 +56,9 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapsho
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(responses) == 4
|
assert len(responses) == 4
|
||||||
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
|
assert all(
|
||||||
|
[r.generated_text == responses[0].generated_text for r in responses]
|
||||||
|
), f"{[r.generated_text for r in responses]}"
|
||||||
assert responses[0].generated_text == ": Let n = 10 - 1"
|
assert responses[0].generated_text == ": Let n = 10 - 1"
|
||||||
|
|
||||||
assert responses == response_snapshot
|
assert responses == response_snapshot
|
||||||
|
|
|
@ -3,7 +3,9 @@ import pytest
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def idefics_handle(launcher):
|
def idefics_handle(launcher):
|
||||||
with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle:
|
with launcher(
|
||||||
|
"HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16"
|
||||||
|
) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -133,8 +133,20 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
||||||
)
|
)
|
||||||
assert all([generation.generated_text is None for generation in generations])
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids])
|
assert all(
|
||||||
assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts])
|
[
|
||||||
|
token_id.item() == 10264
|
||||||
|
for generation in generations
|
||||||
|
for token_id in generation.tokens.token_ids
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
[
|
||||||
|
token_text == "Test"
|
||||||
|
for generation in generations
|
||||||
|
for token_text in generation.tokens.texts
|
||||||
|
]
|
||||||
|
)
|
||||||
assert generations[0].request_id == 0
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -129,8 +129,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
||||||
)
|
)
|
||||||
assert all([generation.generated_text is None for generation in generations])
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids])
|
assert all(
|
||||||
assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts])
|
[
|
||||||
|
token_id.item() == 13
|
||||||
|
for generation in generations
|
||||||
|
for token_id in generation.tokens.token_ids
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
[
|
||||||
|
token_text == "."
|
||||||
|
for generation in generations
|
||||||
|
for token_text in generation.tokens.texts
|
||||||
|
]
|
||||||
|
)
|
||||||
assert generations[0].request_id == 0
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -151,8 +151,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
|
||||||
)
|
)
|
||||||
assert all([generation.generated_text is None for generation in generations])
|
assert all([generation.generated_text is None for generation in generations])
|
||||||
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
|
||||||
assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids])
|
assert all(
|
||||||
assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts])
|
[
|
||||||
|
token_id.item() == 259
|
||||||
|
for generation in generations
|
||||||
|
for token_id in generation.tokens.token_ids
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
[
|
||||||
|
token_text == " "
|
||||||
|
for generation in generations
|
||||||
|
for token_text in generation.tokens.texts
|
||||||
|
]
|
||||||
|
)
|
||||||
assert generations[0].request_id == 0
|
assert generations[0].request_id == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -77,12 +77,24 @@ def serve(
|
||||||
# Downgrade enum into str for easier management later on
|
# Downgrade enum into str for easier management later on
|
||||||
quantize = None if quantize is None else quantize.value
|
quantize = None if quantize is None else quantize.value
|
||||||
dtype = None if dtype is None else dtype.value
|
dtype = None if dtype is None else dtype.value
|
||||||
if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}:
|
if dtype is not None and quantize not in {
|
||||||
|
None,
|
||||||
|
"bitsandbytes",
|
||||||
|
"bitsandbytes-nf4",
|
||||||
|
"bitsandbytes-fp4",
|
||||||
|
}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
|
||||||
)
|
)
|
||||||
server.serve(
|
server.serve(
|
||||||
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
|
model_id,
|
||||||
|
revision,
|
||||||
|
sharded,
|
||||||
|
quantize,
|
||||||
|
speculate,
|
||||||
|
dtype,
|
||||||
|
trust_remote_code,
|
||||||
|
uds_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -140,12 +152,17 @@ def download_weights(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
|
|
||||||
|
medusa_head = hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="medusa_lm_head.pt"
|
||||||
|
)
|
||||||
if auto_convert:
|
if auto_convert:
|
||||||
medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
|
medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
|
||||||
if not medusa_sf.exists():
|
if not medusa_sf.exists():
|
||||||
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
||||||
medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
|
medusa_config = hf_hub_download(
|
||||||
|
model_id, revision=revision, filename="config.json"
|
||||||
|
)
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
|
@ -153,10 +170,17 @@ def download_weights(
|
||||||
revision = "main"
|
revision = "main"
|
||||||
try:
|
try:
|
||||||
utils.weight_files(model_id, revision, extension)
|
utils.weight_files(model_id, revision, extension)
|
||||||
logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
|
logger.info(
|
||||||
|
f"Files for parent {model_id} are already present on the host. "
|
||||||
|
"Skipping download."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
# Local files not found
|
# Local files not found
|
||||||
except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
|
except (
|
||||||
|
utils.LocalEntryNotFoundError,
|
||||||
|
FileNotFoundError,
|
||||||
|
utils.EntryNotFoundError,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -88,7 +88,6 @@ if MIXTRAL:
|
||||||
__all__.append(FlashMixtral)
|
__all__.append(FlashMixtral)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
|
@ -157,7 +156,9 @@ def get_model(
|
||||||
speculate_medusa = config_dict["medusa_num_heads"]
|
speculate_medusa = config_dict["medusa_num_heads"]
|
||||||
if speculate is not None:
|
if speculate is not None:
|
||||||
if speculate > speculate_medusa:
|
if speculate > speculate_medusa:
|
||||||
raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
|
raise RuntimeError(
|
||||||
|
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
set_speculate(speculate)
|
set_speculate(speculate)
|
||||||
else:
|
else:
|
||||||
|
@ -249,7 +250,7 @@ def get_model(
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
use_medusa=use_medusa
|
use_medusa=use_medusa,
|
||||||
)
|
)
|
||||||
elif sharded:
|
elif sharded:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
||||||
|
@ -313,7 +314,9 @@ def get_model(
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
|
raise NotImplementedError(
|
||||||
|
"Mixtral models requires flash attention v2, stk and megablocks"
|
||||||
|
)
|
||||||
|
|
||||||
if model_type == "opt":
|
if model_type == "opt":
|
||||||
return OPTSharded(
|
return OPTSharded(
|
||||||
|
@ -354,7 +357,7 @@ def get_model(
|
||||||
raise ValueError("awq quantization is not supported for AutoModel")
|
raise ValueError("awq quantization is not supported for AutoModel")
|
||||||
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
||||||
raise ValueError("4bit quantization is not supported for AutoModel")
|
raise ValueError("4bit quantization is not supported for AutoModel")
|
||||||
elif (quantize == "eetq"):
|
elif quantize == "eetq":
|
||||||
raise ValueError("Eetq quantization is not supported for AutoModel")
|
raise ValueError("Eetq quantization is not supported for AutoModel")
|
||||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
return CausalLM(
|
return CausalLM(
|
||||||
|
|
|
@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM):
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(
|
weights = Weights(
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
|
filenames,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
process_group=self.process_group,
|
||||||
|
prefix="transformer",
|
||||||
)
|
)
|
||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
weights._set_gptq_params(model_id)
|
weights._set_gptq_params(model_id)
|
||||||
|
|
|
@ -510,7 +510,11 @@ class CausalLM(Model):
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
|
if (
|
||||||
|
torch.cuda.is_available()
|
||||||
|
and torch.cuda.device_count() == 1
|
||||||
|
and quantize != "bitsandbytes"
|
||||||
|
):
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
|
@ -676,7 +680,10 @@ class CausalLM(Model):
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
)
|
)
|
||||||
prefill_tokens = Tokens(
|
prefill_tokens = Tokens(
|
||||||
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
|
prefill_token_ids,
|
||||||
|
prefill_logprobs,
|
||||||
|
prefill_texts,
|
||||||
|
is_special=[],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prefill_tokens = None
|
prefill_tokens = None
|
||||||
|
@ -703,11 +710,11 @@ class CausalLM(Model):
|
||||||
request.id,
|
request.id,
|
||||||
prefill_tokens,
|
prefill_tokens,
|
||||||
Tokens(
|
Tokens(
|
||||||
[next_token_id_squeezed],
|
[next_token_id_squeezed],
|
||||||
[next_token_logprob],
|
[next_token_logprob],
|
||||||
[next_token_text],
|
[next_token_text],
|
||||||
[next_token_id_squeezed.item() in self.all_special_ids],
|
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||||
),
|
),
|
||||||
generated_text,
|
generated_text,
|
||||||
top_tokens,
|
top_tokens,
|
||||||
)
|
)
|
||||||
|
|
|
@ -34,9 +34,10 @@ from text_generation_server.utils.layers import (
|
||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
TensorParallelHead,
|
TensorParallelHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
FastRMSNorm
|
FastRMSNorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LlamaConfig(PretrainedConfig):
|
class LlamaConfig(PretrainedConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -288,7 +289,9 @@ class FlashLlamaLayer(nn.Module):
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
self.post_attention_layernorm = FastRMSNorm.load(
|
self.post_attention_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.post_attention_layernorm",
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
|
|
@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
|
from text_generation_server.utils.flash_attn import (
|
||||||
|
attention,
|
||||||
|
HAS_FLASH_ATTN_V2_ROCM,
|
||||||
|
HAS_FLASH_ATTN_V2_CUDA,
|
||||||
|
)
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -35,7 +39,7 @@ from text_generation_server.utils.layers import (
|
||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
TensorParallelHead,
|
TensorParallelHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
FastRMSNorm
|
FastRMSNorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,6 +100,7 @@ class MistralConfig(PretrainedConfig):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights):
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
return _load_gqa(config, prefix, weights)
|
return _load_gqa(config, prefix, weights)
|
||||||
|
|
|
@ -29,7 +29,10 @@ from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
|
from text_generation_server.utils.flash_attn import (
|
||||||
|
HAS_FLASH_ATTN_V2_ROCM,
|
||||||
|
HAS_FLASH_ATTN_V2_CUDA,
|
||||||
|
)
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
|
@ -59,28 +62,28 @@ class MixtralConfig(PretrainedConfig):
|
||||||
model_type = "mixtral"
|
model_type = "mixtral"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size=32000,
|
vocab_size=32000,
|
||||||
hidden_size=4096,
|
hidden_size=4096,
|
||||||
intermediate_size=14336,
|
intermediate_size=14336,
|
||||||
num_hidden_layers=32,
|
num_hidden_layers=32,
|
||||||
num_attention_heads=32,
|
num_attention_heads=32,
|
||||||
num_key_value_heads=8,
|
num_key_value_heads=8,
|
||||||
hidden_act="silu",
|
hidden_act="silu",
|
||||||
max_position_embeddings=4096 * 32,
|
max_position_embeddings=4096 * 32,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
rms_norm_eps=1e-05,
|
rms_norm_eps=1e-05,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
pad_token_id=None,
|
pad_token_id=None,
|
||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
pretraining_tp=1,
|
pretraining_tp=1,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
rope_theta=10000.0,
|
rope_theta=10000.0,
|
||||||
sliding_window=4096,
|
sliding_window=4096,
|
||||||
num_experts_per_tok=2,
|
num_experts_per_tok=2,
|
||||||
num_local_experts=8,
|
num_local_experts=8,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
@ -166,16 +169,18 @@ def _load_experts(config, prefix, mat, weights):
|
||||||
rank = weights.process_group.rank()
|
rank = weights.process_group.rank()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
config.intermediate_size % world_size == 0
|
config.intermediate_size % world_size == 0
|
||||||
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
|
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||||
|
|
||||||
block_size = config.intermediate_size // world_size
|
block_size = config.intermediate_size // world_size
|
||||||
start = rank * block_size
|
start = rank * block_size
|
||||||
stop = (rank + 1) * block_size
|
stop = (rank + 1) * block_size
|
||||||
|
|
||||||
tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
|
tensor = torch.empty(
|
||||||
dtype=weights.dtype,
|
(config.num_local_experts * block_size, config.hidden_size),
|
||||||
device=weights.device)
|
dtype=weights.dtype,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(config.num_local_experts):
|
for i in range(config.num_local_experts):
|
||||||
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
||||||
|
@ -184,16 +189,18 @@ def _load_experts(config, prefix, mat, weights):
|
||||||
expert_slice = slice_[:, start:stop].t().contiguous()
|
expert_slice = slice_[:, start:stop].t().contiguous()
|
||||||
else:
|
else:
|
||||||
expert_slice = slice_[start:stop]
|
expert_slice = slice_[start:stop]
|
||||||
tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
|
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
|
||||||
|
dtype=weights.dtype
|
||||||
|
).to(device=weights.device)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
class MixtralAttention(torch.nn.Module):
|
class MixtralAttention(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
config,
|
config,
|
||||||
weights,
|
weights,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_past = (
|
self.max_past = (
|
||||||
|
@ -210,7 +217,7 @@ class MixtralAttention(torch.nn.Module):
|
||||||
device=weights.device,
|
device=weights.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.softmax_scale = self.head_size ** -0.5
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
|
||||||
if self.num_heads % weights.process_group.size() != 0:
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -219,7 +226,7 @@ class MixtralAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
self.num_heads = self.num_heads // weights.process_group.size()
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
self.num_key_value_heads = (
|
self.num_key_value_heads = (
|
||||||
config.num_key_value_heads // weights.process_group.size()
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
@ -236,17 +243,17 @@ class MixtralAttention(torch.nn.Module):
|
||||||
).repeat_interleave(self.num_groups)
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
block_tables,
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
|
@ -399,8 +406,9 @@ class BlockSparseMoE(nn.Module):
|
||||||
# Indices for the sparse matrix. The indices for
|
# Indices for the sparse matrix. The indices for
|
||||||
# the intermediate matrix are dynamic depending
|
# the intermediate matrix are dynamic depending
|
||||||
# on the mapping of tokens to experts.
|
# on the mapping of tokens to experts.
|
||||||
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
|
column_indices = ops.topology(
|
||||||
blocks_per_row)
|
padded_bins, self.blocking, block_rows, blocks_per_row
|
||||||
|
)
|
||||||
|
|
||||||
# For now, use meta init to save the device memory.
|
# For now, use meta init to save the device memory.
|
||||||
data = torch.empty(
|
data = torch.empty(
|
||||||
|
@ -444,8 +452,7 @@ class BlockSparseMoE(nn.Module):
|
||||||
# position of each bin.
|
# position of each bin.
|
||||||
|
|
||||||
# List of size num_experts
|
# List of size num_experts
|
||||||
padded_tokens_per_expert = round_up(tokens_per_expert,
|
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
|
||||||
self.blocking)
|
|
||||||
# padded_tokens_per_expert => [128, O, 128, ...]
|
# padded_tokens_per_expert => [128, O, 128, ...]
|
||||||
|
|
||||||
# Cumulative selected experts per token
|
# Cumulative selected experts per token
|
||||||
|
@ -484,8 +491,7 @@ class BlockSparseMoE(nn.Module):
|
||||||
|
|
||||||
# Permute tokens and pad to prepare expert computation
|
# Permute tokens and pad to prepare expert computation
|
||||||
# (top_k * sequence_length + padding, model_dim)
|
# (top_k * sequence_length + padding, model_dim)
|
||||||
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
|
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
|
||||||
self.top_k)
|
|
||||||
|
|
||||||
# Create the sparse matrix topology
|
# Create the sparse matrix topology
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -496,8 +502,8 @@ class BlockSparseMoE(nn.Module):
|
||||||
# (top_k * sequence_length + padding, ffn_dim * n_experts)
|
# (top_k * sequence_length + padding, ffn_dim * n_experts)
|
||||||
x = stk.Matrix(
|
x = stk.Matrix(
|
||||||
topo.size(),
|
topo.size(),
|
||||||
self.act(stk.ops.sdd(x, self.w1, topo).data) *
|
self.act(stk.ops.sdd(x, self.w1, topo).data)
|
||||||
stk.ops.sdd(x, self.w3, topo).data,
|
* stk.ops.sdd(x, self.w3, topo).data,
|
||||||
topo.row_indices,
|
topo.row_indices,
|
||||||
topo.column_indices,
|
topo.column_indices,
|
||||||
topo.offsets,
|
topo.offsets,
|
||||||
|
@ -537,7 +543,9 @@ class MixtralLayer(nn.Module):
|
||||||
self.self_attn = MixtralAttention(
|
self.self_attn = MixtralAttention(
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
)
|
)
|
||||||
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
|
self.block_sparse_moe = BlockSparseMoE(
|
||||||
|
f"{prefix}.block_sparse_moe", config, weights
|
||||||
|
)
|
||||||
|
|
||||||
self.input_layernorm = FastRMSNorm.load(
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
@ -549,18 +557,18 @@ class MixtralLayer(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
block_tables,
|
block_tables,
|
||||||
slots,
|
slots,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
prefill_cache_indices,
|
prefill_cache_indices,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
@ -615,16 +623,16 @@ class MixtralModel(torch.nn.Module):
|
||||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
@ -670,17 +678,17 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
||||||
raise ValueError("max_past cannot be None")
|
raise ValueError("max_past cannot be None")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if prefill_cache_indices is not None:
|
if prefill_cache_indices is not None:
|
||||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||||
|
|
|
@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
|
||||||
image = image_url_or_urls
|
image = image_url_or_urls
|
||||||
|
|
||||||
if image.startswith("http://") or image.startswith("https://"):
|
if image.startswith("http://") or image.startswith("https://"):
|
||||||
response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5))
|
response = requests.get(
|
||||||
|
image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
|
||||||
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
content = response.content
|
content = response.content
|
||||||
elif image.startswith("data:"):
|
elif image.startswith("data:"):
|
||||||
|
|
|
@ -62,6 +62,7 @@ if IS_CUDA_SYSTEM:
|
||||||
elif IS_ROCM_SYSTEM:
|
elif IS_ROCM_SYSTEM:
|
||||||
from vllm import layernorm_ops
|
from vllm import layernorm_ops
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
|
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
|
||||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||||
|
@ -431,7 +432,9 @@ class IdeficsRMSNorm(nn.Module):
|
||||||
|
|
||||||
return out
|
return out
|
||||||
else:
|
else:
|
||||||
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
raise ValueError(
|
||||||
|
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# this was adapted from LlamaMLP
|
# this was adapted from LlamaMLP
|
||||||
|
@ -613,7 +616,12 @@ class IdeficsAttention(nn.Module):
|
||||||
|
|
||||||
query_shape = query_states.shape
|
query_shape = query_states.shape
|
||||||
key_shape = key_states.shape
|
key_shape = key_states.shape
|
||||||
self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin)
|
self.rotary_emb(
|
||||||
|
query_states.view(-1, *query_shape[2:]),
|
||||||
|
key_states.reshape(-1, *key_shape[2:]),
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
|
|
||||||
query_states = query_states.view(query_shape)
|
query_states = query_states.view(query_shape)
|
||||||
key_states = key_states.view(key_shape)
|
key_states = key_states.view(key_shape)
|
||||||
|
|
|
@ -112,6 +112,7 @@ def is_url(string):
|
||||||
result = urlparse(string)
|
result = urlparse(string)
|
||||||
return all([result.scheme, result.netloc])
|
return all([result.scheme, result.netloc])
|
||||||
|
|
||||||
|
|
||||||
def is_image(string):
|
def is_image(string):
|
||||||
"""Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
|
"""Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
|
||||||
invalidated the url"""
|
invalidated the url"""
|
||||||
|
@ -344,7 +345,6 @@ class IdeficsProcessor(ProcessorMixin):
|
||||||
|
|
||||||
image_objects = self.image_processor(image_objects, transform=transform)
|
image_objects = self.image_processor(image_objects, transform=transform)
|
||||||
|
|
||||||
|
|
||||||
text_encoding = self.tokenizer(
|
text_encoding = self.tokenizer(
|
||||||
text=full_text,
|
text=full_text,
|
||||||
add_special_tokens=False,
|
add_special_tokens=False,
|
||||||
|
|
|
@ -165,8 +165,6 @@ class FlashCausalLMBatch(Batch):
|
||||||
input_length = len(tokenized_input)
|
input_length = len(tokenized_input)
|
||||||
input_lengths.append(input_length)
|
input_lengths.append(input_length)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
prefix_offsets.append(input_length - 5)
|
prefix_offsets.append(input_length - 5)
|
||||||
read_offsets.append(input_length)
|
read_offsets.append(input_length)
|
||||||
|
|
||||||
|
@ -229,7 +227,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
cumulative_max_length += total_tokens
|
cumulative_max_length += total_tokens
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
max_blocks = max(max_blocks, needed_blocks)
|
max_blocks = max(max_blocks, needed_blocks)
|
||||||
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
|
max_length = max(
|
||||||
|
max_length, input_length + max_new_tokens + speculative_length
|
||||||
|
)
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype, device
|
next_token_chooser_parameters, dtype, device
|
||||||
|
@ -424,7 +424,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
slots = self.slots[slot_filtering_indices]
|
slots = self.slots[slot_filtering_indices]
|
||||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||||
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
|
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
|
||||||
speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
|
speculative_ids = (
|
||||||
|
self.speculative_ids[indices] if self.speculative_ids is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
|
@ -480,7 +482,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
total_batch_size += len(b)
|
total_batch_size += len(b)
|
||||||
total_slots += len(b.slots)
|
total_slots += len(b.slots)
|
||||||
blocks += b.blocks
|
blocks += b.blocks
|
||||||
speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
|
speculative_length = (
|
||||||
|
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
|
||||||
|
)
|
||||||
max_blocks = max(max_blocks, b.max_blocks)
|
max_blocks = max(max_blocks, b.max_blocks)
|
||||||
max_seqlen = max(max_seqlen, b.max_seqlen)
|
max_seqlen = max(max_seqlen, b.max_seqlen)
|
||||||
max_length = max(
|
max_length = max(
|
||||||
|
@ -586,7 +590,11 @@ class FlashCausalLMBatch(Batch):
|
||||||
device=batches[0].next_token_chooser.device,
|
device=batches[0].next_token_chooser.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
|
speculative_ids = (
|
||||||
|
torch.cat([b.speculative_ids for b in batches], dim=0)
|
||||||
|
if batches[0].speculative_ids is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Needed to avoid dropping blocks when the batches will go out of scope
|
# Needed to avoid dropping blocks when the batches will go out of scope
|
||||||
for b in batches:
|
for b in batches:
|
||||||
|
@ -622,7 +630,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
max_blocks=max_blocks,
|
max_blocks=max_blocks,
|
||||||
speculative_ids=speculative_ids
|
speculative_ids=speculative_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
@ -727,43 +735,54 @@ class FlashCausalLM(Model):
|
||||||
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
if batch.speculative_ids is not None:
|
if batch.speculative_ids is not None:
|
||||||
input_ids=batch.input_ids
|
input_ids = batch.input_ids
|
||||||
position_ids=batch.position_ids
|
position_ids = batch.position_ids
|
||||||
cu_seqlen_prefill=batch.cu_seqlen_prefill
|
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||||
kv_cache=get_cache_manager().kv_cache
|
kv_cache = get_cache_manager().kv_cache
|
||||||
block_tables=batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots=batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths=batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
max_s=batch.max_seqlen
|
max_s = batch.max_seqlen
|
||||||
lm_head_indices=batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
speculative_ids = batch.speculative_ids
|
speculative_ids = batch.speculative_ids
|
||||||
|
|
||||||
B, speculative_length = speculative_ids.shape
|
B, speculative_length = speculative_ids.shape
|
||||||
new_length = speculative_length + 1
|
new_length = speculative_length + 1
|
||||||
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
|
new_input_ids = torch.cat(
|
||||||
|
[input_ids.unsqueeze(-1), speculative_ids], dim=1
|
||||||
|
).reshape(-1)
|
||||||
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
||||||
arange_int = arange.to(dtype=torch.int32)
|
arange_int = arange.to(dtype=torch.int32)
|
||||||
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
|
new_position_ids = (
|
||||||
|
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
||||||
|
).view(-1)
|
||||||
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
||||||
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
input_lengths = (
|
||||||
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
|
).view(-1)
|
||||||
|
|
||||||
# Add Copy the block tables for all members
|
# Add Copy the block tables for all members
|
||||||
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous()
|
block_tables = (
|
||||||
|
block_tables.unsqueeze(1)
|
||||||
|
.expand(B, new_length, -1)
|
||||||
|
.reshape(B * new_length, -1)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
max_s = max_s + speculative_length
|
max_s = max_s + speculative_length
|
||||||
|
|
||||||
input_ids = new_input_ids
|
input_ids = new_input_ids
|
||||||
position_ids = new_position_ids
|
position_ids = new_position_ids
|
||||||
else:
|
else:
|
||||||
input_ids=batch.input_ids
|
input_ids = batch.input_ids
|
||||||
position_ids=batch.position_ids
|
position_ids = batch.position_ids
|
||||||
cu_seqlen_prefill=batch.cu_seqlen_prefill
|
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||||
kv_cache=get_cache_manager().kv_cache
|
kv_cache = get_cache_manager().kv_cache
|
||||||
block_tables=batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots=batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths=batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
max_s=batch.max_seqlen
|
max_s = batch.max_seqlen
|
||||||
lm_head_indices=batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
@ -808,20 +827,31 @@ class FlashCausalLM(Model):
|
||||||
else:
|
else:
|
||||||
speculative_logits = None
|
speculative_logits = None
|
||||||
|
|
||||||
|
|
||||||
if prefill:
|
if prefill:
|
||||||
next_token_logits = (
|
next_token_logits = (
|
||||||
out[batch.prefill_next_token_indices] if prefill_logprobs else out
|
out[batch.prefill_next_token_indices] if prefill_logprobs else out
|
||||||
)
|
)
|
||||||
if speculative_logits is not None:
|
if speculative_logits is not None:
|
||||||
speculative_logits = (
|
speculative_logits = (
|
||||||
speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
|
speculative_logits[batch.prefill_next_token_indices]
|
||||||
|
if prefill_logprobs
|
||||||
|
else speculative_logits
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
next_token_logits = out
|
next_token_logits = out
|
||||||
|
|
||||||
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
|
(
|
||||||
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
|
next_input_ids,
|
||||||
|
next_token_logprobs,
|
||||||
|
logprobs,
|
||||||
|
accepted_ids,
|
||||||
|
speculative_ids,
|
||||||
|
) = batch.next_token_chooser(
|
||||||
|
batch.all_input_ids_tensor[:, : batch.max_seqlen],
|
||||||
|
next_token_logits,
|
||||||
|
get_speculate(),
|
||||||
|
batch.speculative_ids,
|
||||||
|
speculative_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||||
|
@ -851,11 +881,7 @@ class FlashCausalLM(Model):
|
||||||
stopped = True
|
stopped = True
|
||||||
|
|
||||||
# Zipped iterator
|
# Zipped iterator
|
||||||
iterator = zip(
|
iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
|
||||||
batch.input_lengths,
|
|
||||||
batch.all_input_ids,
|
|
||||||
accepted_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
|
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
|
||||||
# one, we need to first do a GPU <-> CPU sync
|
# one, we need to first do a GPU <-> CPU sync
|
||||||
|
@ -863,11 +889,7 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
index = 0
|
index = 0
|
||||||
for i, (
|
for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
|
||||||
input_length,
|
|
||||||
all_input_ids,
|
|
||||||
n_accepted_ids
|
|
||||||
) in enumerate(iterator):
|
|
||||||
# Indexing metadata
|
# Indexing metadata
|
||||||
start_index = cumulative_length
|
start_index = cumulative_length
|
||||||
end_index = cumulative_length + input_length
|
end_index = cumulative_length + input_length
|
||||||
|
@ -901,7 +923,6 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
|
|
||||||
|
|
||||||
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
||||||
batch.speculative_ids = speculative_ids
|
batch.speculative_ids = speculative_ids
|
||||||
batch.position_ids = next_position_ids + accepted_ids
|
batch.position_ids = next_position_ids + accepted_ids
|
||||||
|
@ -983,8 +1004,10 @@ class FlashCausalLM(Model):
|
||||||
current_stopped = False
|
current_stopped = False
|
||||||
stopped = stopped and current_stopped
|
stopped = stopped and current_stopped
|
||||||
|
|
||||||
_next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
|
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
|
||||||
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
|
_next_token_logprobs = next_token_logprobs[
|
||||||
|
index : index + n_accepted_ids - left
|
||||||
|
]
|
||||||
index += n_accepted_ids
|
index += n_accepted_ids
|
||||||
|
|
||||||
# Shard generations
|
# Shard generations
|
||||||
|
@ -1027,7 +1050,10 @@ class FlashCausalLM(Model):
|
||||||
)
|
)
|
||||||
|
|
||||||
prefill_tokens = Tokens(
|
prefill_tokens = Tokens(
|
||||||
prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
|
prefill_token_ids,
|
||||||
|
request_prefill_logprobs,
|
||||||
|
prefill_texts,
|
||||||
|
is_special=[],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prefill_tokens = None
|
prefill_tokens = None
|
||||||
|
|
|
@ -71,12 +71,19 @@ class FlashLlama(FlashCausalLM):
|
||||||
from text_generation_server.utils.medusa import MedusaModel
|
from text_generation_server.utils.medusa import MedusaModel
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import json
|
import json
|
||||||
medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
|
|
||||||
|
medusa_config = hf_hub_download(
|
||||||
|
use_medusa, revision=revision, filename="config.json"
|
||||||
|
)
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
|
medusa_head = hf_hub_download(
|
||||||
medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
|
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||||
weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
|
)
|
||||||
|
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||||
|
weights = Weights(
|
||||||
|
[medusa_sf], device, dtype, process_group=self.process_group
|
||||||
|
)
|
||||||
lm_head = model.lm_head
|
lm_head = model.lm_head
|
||||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||||
|
|
||||||
|
|
|
@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
pb: generate_pb2.Batch,
|
pb: generate_pb2.Batch,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "FlashCausalLMBatch":
|
) -> "FlashCausalLMBatch":
|
||||||
global SLIDING_WINDOW
|
global SLIDING_WINDOW
|
||||||
global SLIDING_WINDOW_BLOCKS
|
global SLIDING_WINDOW_BLOCKS
|
||||||
|
@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
for i, (r, tokenized_input) in enumerate(
|
for i, (r, tokenized_input) in enumerate(
|
||||||
zip(pb.requests, batch_tokenized_inputs)
|
zip(pb.requests, batch_tokenized_inputs)
|
||||||
):
|
):
|
||||||
# request id -> idx in list mapping
|
# request id -> idx in list mapping
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
|
|
||||||
tokenized_input = tokenized_input[-r.truncate:]
|
tokenized_input = tokenized_input[-r.truncate :]
|
||||||
|
|
||||||
input_length = len(tokenized_input)
|
input_length = len(tokenized_input)
|
||||||
input_lengths.append(input_length)
|
input_lengths.append(input_length)
|
||||||
|
@ -184,7 +184,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
cumulative_max_length += total_tokens
|
cumulative_max_length += total_tokens
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
max_blocks = max(max_blocks, needed_blocks)
|
max_blocks = max(max_blocks, needed_blocks)
|
||||||
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
|
max_length = max(
|
||||||
|
max_length, input_length + max_new_tokens + speculative_length
|
||||||
|
)
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype, device
|
next_token_chooser_parameters, dtype, device
|
||||||
|
@ -273,20 +275,20 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
max_blocks=max_blocks,
|
max_blocks=max_blocks,
|
||||||
prefill_cache_indices=prefill_cache_indices,
|
prefill_cache_indices=prefill_cache_indices,
|
||||||
speculative_ids=None
|
speculative_ids=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseFlashMistral(FlashCausalLM):
|
class BaseFlashMistral(FlashCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config_cls,
|
config_cls,
|
||||||
model_cls,
|
model_cls,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
global SLIDING_WINDOW
|
global SLIDING_WINDOW
|
||||||
global SLIDING_WINDOW_BLOCKS
|
global SLIDING_WINDOW_BLOCKS
|
||||||
|
@ -345,43 +347,54 @@ class BaseFlashMistral(FlashCausalLM):
|
||||||
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
if batch.speculative_ids is not None:
|
if batch.speculative_ids is not None:
|
||||||
input_ids=batch.input_ids
|
input_ids = batch.input_ids
|
||||||
position_ids=batch.position_ids
|
position_ids = batch.position_ids
|
||||||
cu_seqlen_prefill=batch.cu_seqlen_prefill
|
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||||
kv_cache=get_cache_manager().kv_cache
|
kv_cache = get_cache_manager().kv_cache
|
||||||
block_tables=batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots=batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths=batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
max_s=batch.max_seqlen
|
max_s = batch.max_seqlen
|
||||||
lm_head_indices=batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
speculative_ids = batch.speculative_ids
|
speculative_ids = batch.speculative_ids
|
||||||
|
|
||||||
B, speculative_length = speculative_ids.shape
|
B, speculative_length = speculative_ids.shape
|
||||||
new_length = speculative_length + 1
|
new_length = speculative_length + 1
|
||||||
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
|
new_input_ids = torch.cat(
|
||||||
|
[input_ids.unsqueeze(-1), speculative_ids], dim=1
|
||||||
|
).reshape(-1)
|
||||||
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
||||||
arange_int = arange.to(dtype=torch.int32)
|
arange_int = arange.to(dtype=torch.int32)
|
||||||
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
|
new_position_ids = (
|
||||||
|
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
||||||
|
).view(-1)
|
||||||
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
||||||
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
input_lengths = (
|
||||||
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
|
).view(-1)
|
||||||
|
|
||||||
# Add Copy the block tables for all members
|
# Add Copy the block tables for all members
|
||||||
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous()
|
block_tables = (
|
||||||
|
block_tables.unsqueeze(1)
|
||||||
|
.expand(B, new_length, -1)
|
||||||
|
.reshape(B * new_length, -1)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
max_s = max_s + speculative_length
|
max_s = max_s + speculative_length
|
||||||
|
|
||||||
input_ids = new_input_ids
|
input_ids = new_input_ids
|
||||||
position_ids = new_position_ids
|
position_ids = new_position_ids
|
||||||
else:
|
else:
|
||||||
input_ids=batch.input_ids
|
input_ids = batch.input_ids
|
||||||
position_ids=batch.position_ids
|
position_ids = batch.position_ids
|
||||||
cu_seqlen_prefill=batch.cu_seqlen_prefill
|
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||||
kv_cache=get_cache_manager().kv_cache
|
kv_cache = get_cache_manager().kv_cache
|
||||||
block_tables=batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots=batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths=batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
max_s=batch.max_seqlen
|
max_s = batch.max_seqlen
|
||||||
lm_head_indices=batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
logits = self.model.forward(
|
logits = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
@ -401,12 +414,12 @@ class BaseFlashMistral(FlashCausalLM):
|
||||||
|
|
||||||
class FlashMistral(BaseFlashMistral):
|
class FlashMistral(BaseFlashMistral):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
super(FlashMistral, self).__init__(
|
super(FlashMistral, self).__init__(
|
||||||
config_cls=MistralConfig,
|
config_cls=MistralConfig,
|
||||||
|
@ -415,5 +428,5 @@ class FlashMistral(BaseFlashMistral):
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,17 +3,20 @@ import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from text_generation_server.models.flash_mistral import BaseFlashMistral
|
from text_generation_server.models.flash_mistral import BaseFlashMistral
|
||||||
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
|
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
||||||
|
MixtralConfig,
|
||||||
|
FlashMixtralForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashMixtral(BaseFlashMistral):
|
class FlashMixtral(BaseFlashMistral):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
super(FlashMixtral, self).__init__(
|
super(FlashMixtral, self).__init__(
|
||||||
config_cls=MixtralConfig,
|
config_cls=MixtralConfig,
|
||||||
|
@ -22,5 +25,5 @@ class FlashMixtral(BaseFlashMistral):
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
|
@ -792,7 +792,10 @@ class IdeficsCausalLM(Model):
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
)
|
)
|
||||||
prefill_tokens = Tokens(
|
prefill_tokens = Tokens(
|
||||||
prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
|
prefill_token_ids,
|
||||||
|
prefill_logprobs,
|
||||||
|
prefill_texts,
|
||||||
|
is_special=[],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prefill_tokens = None
|
prefill_tokens = None
|
||||||
|
@ -803,10 +806,10 @@ class IdeficsCausalLM(Model):
|
||||||
request.id,
|
request.id,
|
||||||
prefill_tokens,
|
prefill_tokens,
|
||||||
Tokens(
|
Tokens(
|
||||||
[next_token_id_squeezed],
|
[next_token_id_squeezed],
|
||||||
[next_token_logprob],
|
[next_token_logprob],
|
||||||
[next_token_text],
|
[next_token_text],
|
||||||
[next_token_id_squeezed.item() in self.all_special_ids],
|
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||||
),
|
),
|
||||||
generated_text,
|
generated_text,
|
||||||
top_tokens,
|
top_tokens,
|
||||||
|
|
|
@ -56,7 +56,7 @@ class Model(ABC):
|
||||||
dtype=str(self.dtype),
|
dtype=str(self.dtype),
|
||||||
device_type=self.device.type,
|
device_type=self.device.type,
|
||||||
window_size=self.sliding_window,
|
window_size=self.sliding_window,
|
||||||
speculate=self.speculate
|
speculate=self.speculate,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -736,7 +736,7 @@ class Seq2SeqLM(Model):
|
||||||
[self.tokenizer.bos_token_id],
|
[self.tokenizer.bos_token_id],
|
||||||
[float("nan")],
|
[float("nan")],
|
||||||
[self.tokenizer.bos_token],
|
[self.tokenizer.bos_token],
|
||||||
[False]
|
[False],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prefill_tokens = None
|
prefill_tokens = None
|
||||||
|
@ -763,10 +763,10 @@ class Seq2SeqLM(Model):
|
||||||
request.id,
|
request.id,
|
||||||
prefill_tokens,
|
prefill_tokens,
|
||||||
Tokens(
|
Tokens(
|
||||||
[next_token_id_squeezed],
|
[next_token_id_squeezed],
|
||||||
[next_token_logprob],
|
[next_token_logprob],
|
||||||
[next_token_text],
|
[next_token_text],
|
||||||
[next_token_id_squeezed.item() in self.all_special_ids],
|
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||||
),
|
),
|
||||||
generated_text,
|
generated_text,
|
||||||
top_tokens,
|
top_tokens,
|
||||||
|
|
|
@ -66,7 +66,10 @@ class Tokens:
|
||||||
|
|
||||||
def to_pb(self) -> generate_pb2.Tokens:
|
def to_pb(self) -> generate_pb2.Tokens:
|
||||||
return generate_pb2.Tokens(
|
return generate_pb2.Tokens(
|
||||||
ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
|
ids=self.token_ids,
|
||||||
|
logprobs=self.logprobs,
|
||||||
|
texts=self.texts,
|
||||||
|
is_special=self.is_special,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
|
@ -159,7 +159,13 @@ def serve(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = get_model(
|
model = get_model(
|
||||||
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
|
model_id,
|
||||||
|
revision,
|
||||||
|
sharded,
|
||||||
|
quantize,
|
||||||
|
speculate,
|
||||||
|
dtype,
|
||||||
|
trust_remote_code,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error when initializing model")
|
logger.exception("Error when initializing model")
|
||||||
|
@ -207,5 +213,7 @@ def serve(
|
||||||
await server.stop(0)
|
await server.stop(0)
|
||||||
|
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
|
serve_inner(
|
||||||
|
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -51,7 +51,9 @@ except ImportError as e:
|
||||||
) from e
|
) from e
|
||||||
elif IS_ROCM_SYSTEM:
|
elif IS_ROCM_SYSTEM:
|
||||||
for idx in range(torch.cuda.device_count()):
|
for idx in range(torch.cuda.device_count()):
|
||||||
if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
|
if "MI210" not in torch.cuda.get_device_name(
|
||||||
|
idx
|
||||||
|
) and "MI250" not in torch.cuda.get_device_name(idx):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
||||||
)
|
)
|
||||||
|
@ -91,7 +93,9 @@ def attention(
|
||||||
)
|
)
|
||||||
elif HAS_FLASH_ATTN_V2_ROCM:
|
elif HAS_FLASH_ATTN_V2_ROCM:
|
||||||
if window_size_left != -1:
|
if window_size_left != -1:
|
||||||
raise ValueError(f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left}).")
|
raise ValueError(
|
||||||
|
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||||
|
)
|
||||||
|
|
||||||
# RoCm flash API does not take the window_size_left and window_size_right arguments.
|
# RoCm flash API does not take the window_size_left and window_size_right arguments.
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
|
|
|
@ -11,20 +11,22 @@ logger = getLogger(__name__)
|
||||||
try:
|
try:
|
||||||
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error('exllamav2_kernels not installed.')
|
logger.error("exllamav2_kernels not installed.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||||
none_tensor = torch.empty((1, 1), device="meta")
|
none_tensor = torch.empty((1, 1), device="meta")
|
||||||
|
|
||||||
|
|
||||||
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||||
"""Matrix multiplication, returns x @ q4"""
|
"""Matrix multiplication, returns x @ q4"""
|
||||||
output_shape = x.shape[:-1] + (q4_width,)
|
output_shape = x.shape[:-1] + (q4_width,)
|
||||||
x = x.view(-1, x.shape[-1])
|
x = x.view(-1, x.shape[-1])
|
||||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device)
|
output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
|
||||||
gemm_half_q_half(x, q_handle, output, force_cuda)
|
gemm_half_q_half(x, q_handle, output, force_cuda)
|
||||||
return output.view(output_shape)
|
return output.view(output_shape)
|
||||||
|
|
||||||
|
|
||||||
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
"""
|
"""
|
||||||
Create Q matrix
|
Create Q matrix
|
||||||
|
@ -35,16 +37,18 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
w["q_scale_max"] /= 256
|
w["q_scale_max"] /= 256
|
||||||
w["q_perm"] = w["q_perm"].short()
|
w["q_perm"] = w["q_perm"].short()
|
||||||
w["q_invperm"] = w["q_invperm"].short()
|
w["q_invperm"] = w["q_invperm"].short()
|
||||||
return make_q_matrix(w["q_weight"],
|
return make_q_matrix(
|
||||||
w["q_perm"],
|
w["q_weight"],
|
||||||
w["q_invperm"],
|
w["q_perm"],
|
||||||
w["q_scale"],
|
w["q_invperm"],
|
||||||
w["q_scale_max"],
|
w["q_scale"],
|
||||||
w["q_groups"],
|
w["q_scale_max"],
|
||||||
none_tensor,
|
w["q_groups"],
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
temp_dq)
|
none_tensor,
|
||||||
|
temp_dq,
|
||||||
|
)
|
||||||
# GPTQ
|
# GPTQ
|
||||||
elif "qweight" in w:
|
elif "qweight" in w:
|
||||||
if w["scales"].dtype == torch.float:
|
if w["scales"].dtype == torch.float:
|
||||||
|
@ -52,31 +56,40 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
|
|
||||||
# GPTQ with g_idx (act_order)
|
# GPTQ with g_idx (act_order)
|
||||||
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
|
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
|
||||||
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
|
w["q_perm"] = torch.empty(
|
||||||
|
(w["qweight"].shape[0] * 8,),
|
||||||
|
dtype=torch.short,
|
||||||
|
device=w["qweight"].device,
|
||||||
|
)
|
||||||
w["q_invperm"] = torch.empty_like(w["q_perm"])
|
w["q_invperm"] = torch.empty_like(w["q_perm"])
|
||||||
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||||
return make_q_matrix(w["qweight"],
|
return make_q_matrix(
|
||||||
w["q_perm"],
|
w["qweight"],
|
||||||
w["q_invperm"],
|
w["q_perm"],
|
||||||
none_tensor,
|
w["q_invperm"],
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
w["qzeros"],
|
none_tensor,
|
||||||
w["scales"],
|
w["qzeros"],
|
||||||
w["g_idx"].cpu(),
|
w["scales"],
|
||||||
temp_dq)
|
w["g_idx"].cpu(),
|
||||||
|
temp_dq,
|
||||||
|
)
|
||||||
# GPTQ without g_idx
|
# GPTQ without g_idx
|
||||||
else:
|
else:
|
||||||
return make_q_matrix(w["qweight"],
|
return make_q_matrix(
|
||||||
none_tensor,
|
w["qweight"],
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
none_tensor,
|
none_tensor,
|
||||||
w["qzeros"],
|
none_tensor,
|
||||||
w["scales"],
|
w["qzeros"],
|
||||||
none_tensor,
|
w["scales"],
|
||||||
temp_dq)
|
none_tensor,
|
||||||
|
temp_dq,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DEVICE = None
|
DEVICE = None
|
||||||
FIXED_BYTES = 0
|
FIXED_BYTES = 0
|
||||||
|
@ -106,14 +119,15 @@ class QuantLinear(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if bits != 4:
|
if bits != 4:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
|
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
|
||||||
|
)
|
||||||
self.q_handle = None
|
self.q_handle = None
|
||||||
self.q_tensors = None
|
self.q_tensors = None
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
self.maxq = 2 ** self.bits - 1
|
self.maxq = 2**self.bits - 1
|
||||||
self.infeatures = qweight.shape[0] // self.bits * 32
|
self.infeatures = qweight.shape[0] // self.bits * 32
|
||||||
self.outfeatures = qweight.shape[1]
|
self.outfeatures = qweight.shape[1]
|
||||||
self.padding = - self.outfeatures % 32
|
self.padding = -self.outfeatures % 32
|
||||||
self.outfeatures = self.outfeatures + self.padding
|
self.outfeatures = self.outfeatures + self.padding
|
||||||
|
|
||||||
self.device = qweight.device
|
self.device = qweight.device
|
||||||
|
@ -128,9 +142,12 @@ class QuantLinear(nn.Module):
|
||||||
outfeatures = self.outfeatures
|
outfeatures = self.outfeatures
|
||||||
assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
|
assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
|
||||||
assert infeatures % self.group_size == 0
|
assert infeatures % self.group_size == 0
|
||||||
assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits)
|
assert qzeros.shape == (
|
||||||
|
infeatures // self.group_size,
|
||||||
|
outfeatures // 32 * self.bits,
|
||||||
|
)
|
||||||
assert scales.shape == (infeatures // self.group_size, outfeatures)
|
assert scales.shape == (infeatures // self.group_size, outfeatures)
|
||||||
assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}"
|
assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
|
||||||
|
|
||||||
global FIXED_BYTES, LAYERS
|
global FIXED_BYTES, LAYERS
|
||||||
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
|
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
|
||||||
|
@ -140,17 +157,15 @@ class QuantLinear(nn.Module):
|
||||||
assert self.qweight.device.type == "cuda"
|
assert self.qweight.device.type == "cuda"
|
||||||
assert self.qweight.device.index is not None
|
assert self.qweight.device.index is not None
|
||||||
self.q_tensors = {
|
self.q_tensors = {
|
||||||
"qweight":self.qweight,
|
"qweight": self.qweight,
|
||||||
"qzeros":self.qzeros,
|
"qzeros": self.qzeros,
|
||||||
"scales":self.scales,
|
"scales": self.scales,
|
||||||
"g_idx":self.g_idx
|
"g_idx": self.g_idx,
|
||||||
}
|
}
|
||||||
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||||
self.q_handle = ext_make_q_matrix(
|
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
|
||||||
self.q_tensors, temp_dq
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, force_cuda = False):
|
def forward(self, x, force_cuda=False):
|
||||||
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
|
@ -179,11 +194,14 @@ class ExLlamaV2DeviceTensors:
|
||||||
self.scratch_bytes = scratch_bytes
|
self.scratch_bytes = scratch_bytes
|
||||||
|
|
||||||
def prepare(self):
|
def prepare(self):
|
||||||
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = self.device)
|
self.scratch = torch.empty(
|
||||||
|
(self.scratch_bytes // 2,), dtype=torch.half, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
def get_scratch_slice(self, size_bytes):
|
def get_scratch_slice(self, size_bytes):
|
||||||
|
|
||||||
if self.scratch is None: self.prepare()
|
if self.scratch is None:
|
||||||
|
self.prepare()
|
||||||
|
|
||||||
size_bytes = ((size_bytes + 127) // 128) * 128
|
size_bytes = ((size_bytes + 127) // 128) * 128
|
||||||
size_half = size_bytes // 2
|
size_half = size_bytes // 2
|
||||||
|
|
|
@ -35,7 +35,9 @@ HAS_EXLLAMA = False
|
||||||
CAN_EXLLAMA = major >= 8
|
CAN_EXLLAMA = major >= 8
|
||||||
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
||||||
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
|
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
|
||||||
logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding")
|
logger.warning(
|
||||||
|
"Disabling exllama v2 and using v1 instead because there are issues when sharding"
|
||||||
|
)
|
||||||
V2 = False
|
V2 = False
|
||||||
|
|
||||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||||
|
@ -43,17 +45,19 @@ if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||||
elif CAN_EXLLAMA:
|
elif CAN_EXLLAMA:
|
||||||
try:
|
try:
|
||||||
if V2:
|
if V2:
|
||||||
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
|
from text_generation_server.utils.gptq.exllamav2 import (
|
||||||
create_exllama_buffers,
|
QuantLinear as ExllamaQuantLinear,
|
||||||
set_device,
|
create_exllama_buffers,
|
||||||
)
|
set_device,
|
||||||
|
)
|
||||||
|
|
||||||
HAS_EXLLAMA = "2"
|
HAS_EXLLAMA = "2"
|
||||||
else:
|
else:
|
||||||
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
|
from text_generation_server.utils.gptq.exllama import (
|
||||||
create_exllama_buffers,
|
Ex4bitLinear as ExllamaQuantLinear,
|
||||||
set_device,
|
create_exllama_buffers,
|
||||||
)
|
set_device,
|
||||||
|
)
|
||||||
|
|
||||||
HAS_EXLLAMA = "1"
|
HAS_EXLLAMA = "1"
|
||||||
|
|
||||||
|
@ -114,7 +118,7 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_conv2d_no_bias(
|
def load_conv2d_no_bias(
|
||||||
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
|
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
|
||||||
):
|
):
|
||||||
weight = weights.get_tensor(f"{prefix}.weight")
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
|
@ -138,9 +142,9 @@ torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
||||||
|
|
||||||
class FastLinear(nn.Module):
|
class FastLinear(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight,
|
weight,
|
||||||
bias,
|
bias,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(weight)
|
self.weight = nn.Parameter(weight)
|
||||||
|
@ -164,9 +168,9 @@ class FastLinear(nn.Module):
|
||||||
|
|
||||||
class EETQLinear(nn.Module):
|
class EETQLinear(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight,
|
weight,
|
||||||
bias,
|
bias,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
device = weight.device
|
device = weight.device
|
||||||
|
@ -185,13 +189,13 @@ class EETQLinear(nn.Module):
|
||||||
|
|
||||||
class Linear8bitLt(nn.Module):
|
class Linear8bitLt(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight,
|
weight,
|
||||||
bias,
|
bias,
|
||||||
has_fp16_weights=True,
|
has_fp16_weights=True,
|
||||||
memory_efficient_backward=False,
|
memory_efficient_backward=False,
|
||||||
threshold=0.0,
|
threshold=0.0,
|
||||||
index=None,
|
index=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert (
|
assert (
|
||||||
|
@ -325,7 +329,9 @@ def get_linear(weight, bias, quantize):
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_exllama:
|
if use_exllama:
|
||||||
linear = ExllamaQuantLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
linear = ExllamaQuantLinear(
|
||||||
|
qweight, qzeros, scales, g_idx, bias, bits, groupsize
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
linear = QuantLinear(
|
linear = QuantLinear(
|
||||||
qweight,
|
qweight,
|
||||||
|
@ -533,7 +539,6 @@ try:
|
||||||
else:
|
else:
|
||||||
dropout_layer_norm = None
|
dropout_layer_norm = None
|
||||||
|
|
||||||
|
|
||||||
class FastLayerNorm(nn.LayerNorm):
|
class FastLayerNorm(nn.LayerNorm):
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||||
|
@ -569,7 +574,6 @@ try:
|
||||||
|
|
||||||
return normed_hidden_states, residual
|
return normed_hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
class FastRMSNorm(nn.Module):
|
class FastRMSNorm(nn.Module):
|
||||||
def __init__(self, weight: torch.Tensor, eps: float):
|
def __init__(self, weight: torch.Tensor, eps: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -601,7 +605,11 @@ try:
|
||||||
return self.weight * hidden_states, residual
|
return self.weight * hidden_states, residual
|
||||||
elif IS_CUDA_SYSTEM:
|
elif IS_CUDA_SYSTEM:
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
(
|
||||||
|
normed_hidden_states,
|
||||||
|
res,
|
||||||
|
*rest,
|
||||||
|
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
self.weight,
|
self.weight,
|
||||||
|
@ -638,7 +646,8 @@ try:
|
||||||
return out, residual
|
return out, residual
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||||
|
)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
@ -650,14 +659,12 @@ try:
|
||||||
elif IS_ROCM_SYSTEM:
|
elif IS_ROCM_SYSTEM:
|
||||||
from vllm import pos_encoding_ops
|
from vllm import pos_encoding_ops
|
||||||
|
|
||||||
|
|
||||||
def _create_inv_freq(dim, base, device):
|
def _create_inv_freq(dim, base, device):
|
||||||
inv_freq = 1.0 / (
|
inv_freq = 1.0 / (
|
||||||
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||||
)
|
)
|
||||||
return inv_freq
|
return inv_freq
|
||||||
|
|
||||||
|
|
||||||
def _get_rope_config(config):
|
def _get_rope_config(config):
|
||||||
if os.getenv("ROPE_SCALING", None) is not None:
|
if os.getenv("ROPE_SCALING", None) is not None:
|
||||||
rope_scaling = {
|
rope_scaling = {
|
||||||
|
@ -667,7 +674,6 @@ try:
|
||||||
return rope_scaling
|
return rope_scaling
|
||||||
return getattr(config, "rope_scaling", None)
|
return getattr(config, "rope_scaling", None)
|
||||||
|
|
||||||
|
|
||||||
class PositionRotaryEmbedding(nn.Module):
|
class PositionRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, inv_freq, scaling_factor):
|
def __init__(self, inv_freq, scaling_factor):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -680,17 +686,23 @@ try:
|
||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
self.dynamic_args = None
|
self.dynamic_args = None
|
||||||
|
|
||||||
def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
):
|
||||||
# Such controlflows may add some overhead.
|
# Such controlflows may add some overhead.
|
||||||
if IS_CUDA_SYSTEM:
|
if IS_CUDA_SYSTEM:
|
||||||
rotary_dim = cos.shape[-1]
|
rotary_dim = cos.shape[-1]
|
||||||
q1 = query[..., :rotary_dim]
|
q1 = query[..., :rotary_dim]
|
||||||
q2 = query[..., rotary_dim: 2 * rotary_dim]
|
q2 = query[..., rotary_dim : 2 * rotary_dim]
|
||||||
|
|
||||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||||
|
|
||||||
k1 = key[..., :rotary_dim]
|
k1 = key[..., :rotary_dim]
|
||||||
k2 = key[..., rotary_dim: 2 * rotary_dim]
|
k2 = key[..., rotary_dim : 2 * rotary_dim]
|
||||||
|
|
||||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||||
elif IS_ROCM_SYSTEM:
|
elif IS_ROCM_SYSTEM:
|
||||||
|
@ -700,17 +712,11 @@ try:
|
||||||
head_size = query.shape[-1]
|
head_size = query.shape[-1]
|
||||||
|
|
||||||
# Inplace operation, updating query and key.
|
# Inplace operation, updating query and key.
|
||||||
pos_encoding_ops.rotary_embedding(
|
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||||
query,
|
|
||||||
key,
|
|
||||||
head_size,
|
|
||||||
cos,
|
|
||||||
sin,
|
|
||||||
True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def static(cls, config, dim, base, device):
|
def static(cls, config, dim, base, device):
|
||||||
|
@ -732,15 +738,16 @@ try:
|
||||||
elif rope_scaling["type"] == "yarn":
|
elif rope_scaling["type"] == "yarn":
|
||||||
return YarnPositionRotaryEmbedding(
|
return YarnPositionRotaryEmbedding(
|
||||||
dim=2 * inv_freq.shape[0],
|
dim=2 * inv_freq.shape[0],
|
||||||
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
max_position_embeddings=rope_scaling[
|
||||||
|
"original_max_position_embeddings"
|
||||||
|
],
|
||||||
base=10000.0,
|
base=10000.0,
|
||||||
device=inv_freq.device,
|
device=inv_freq.device,
|
||||||
scaling_factor=scaling_factor,
|
scaling_factor=scaling_factor,
|
||||||
extrapolation_factor=1,
|
extrapolation_factor=1,
|
||||||
attn_factor=1,
|
attn_factor=1,
|
||||||
beta_fast=32,
|
beta_fast=32,
|
||||||
beta_slow=1
|
beta_slow=1,
|
||||||
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
@ -773,15 +780,16 @@ try:
|
||||||
elif rope_scaling["type"] == "yarn":
|
elif rope_scaling["type"] == "yarn":
|
||||||
return YarnPositionRotaryEmbedding(
|
return YarnPositionRotaryEmbedding(
|
||||||
dim=2 * inv_freq.shape[0],
|
dim=2 * inv_freq.shape[0],
|
||||||
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
max_position_embeddings=rope_scaling[
|
||||||
|
"original_max_position_embeddings"
|
||||||
|
],
|
||||||
base=10000.0,
|
base=10000.0,
|
||||||
device=inv_freq.device,
|
device=inv_freq.device,
|
||||||
scaling_factor=scaling_factor,
|
scaling_factor=scaling_factor,
|
||||||
extrapolation_factor=1,
|
extrapolation_factor=1,
|
||||||
attn_factor=1,
|
attn_factor=1,
|
||||||
beta_fast=32,
|
beta_fast=32,
|
||||||
beta_slow=1
|
beta_slow=1,
|
||||||
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
@ -793,9 +801,9 @@ try:
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
# or if we're on a new device (possibly due to tracing for instance)
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
if (
|
if (
|
||||||
seqlen > self._seq_len_cached
|
seqlen > self._seq_len_cached
|
||||||
or self._cos_cached.device != device
|
or self._cos_cached.device != device
|
||||||
or self._cos_cached.dtype != dtype
|
or self._cos_cached.dtype != dtype
|
||||||
):
|
):
|
||||||
self._seq_len_cached = seqlen
|
self._seq_len_cached = seqlen
|
||||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
@ -809,7 +817,7 @@ try:
|
||||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
|
||||||
def get_cos_sin(
|
def get_cos_sin(
|
||||||
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Return cos and sin for the asked position ids
|
Return cos and sin for the asked position ids
|
||||||
|
@ -827,7 +835,6 @@ try:
|
||||||
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
||||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||||
inv_freq = _create_inv_freq(dim, base, device)
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
|
@ -840,14 +847,14 @@ try:
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
# or if we're on a new device (possibly due to tracing for instance)
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
if (
|
if (
|
||||||
seqlen > self._seq_len_cached
|
seqlen > self._seq_len_cached
|
||||||
or self._cos_cached.device != device
|
or self._cos_cached.device != device
|
||||||
or self._cos_cached.dtype != dtype
|
or self._cos_cached.dtype != dtype
|
||||||
):
|
):
|
||||||
if seqlen > self.max_position_embeddings:
|
if seqlen > self.max_position_embeddings:
|
||||||
newbase = self.base * (
|
newbase = self.base * (
|
||||||
(self.scaling_factor * seqlen / self.max_position_embeddings)
|
(self.scaling_factor * seqlen / self.max_position_embeddings)
|
||||||
- (self.scaling_factor - 1)
|
- (self.scaling_factor - 1)
|
||||||
) ** (self.dim / (self.dim - 2))
|
) ** (self.dim / (self.dim - 2))
|
||||||
self.inv_freq = _create_inv_freq(
|
self.inv_freq = _create_inv_freq(
|
||||||
self.dim, newbase, self.inv_freq.device
|
self.dim, newbase, self.inv_freq.device
|
||||||
|
@ -861,24 +868,28 @@ try:
|
||||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
|
||||||
|
|
||||||
# Inverse dim formula to find dim based on number of rotations
|
# Inverse dim formula to find dim based on number of rotations
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
def find_correction_dim(
|
||||||
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
):
|
||||||
|
return (
|
||||||
|
dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
|
||||||
|
) / (2 * math.log(base))
|
||||||
|
|
||||||
# Find dim range bounds based on rotations
|
# Find dim range bounds based on rotations
|
||||||
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
def find_correction_range(
|
||||||
low = math.floor(find_correction_dim(
|
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||||
low_rot, dim, base, max_position_embeddings))
|
):
|
||||||
high = math.ceil(find_correction_dim(
|
low = math.floor(
|
||||||
high_rot, dim, base, max_position_embeddings))
|
find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||||
|
)
|
||||||
|
high = math.ceil(
|
||||||
|
find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||||
|
)
|
||||||
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||||
|
|
||||||
|
|
||||||
def linear_ramp_mask(min, max, dim):
|
def linear_ramp_mask(min, max, dim):
|
||||||
if min == max:
|
if min == max:
|
||||||
max += 0.001 # Prevent singularity
|
max += 0.001 # Prevent singularity
|
||||||
|
@ -887,16 +898,25 @@ try:
|
||||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||||
return ramp_func
|
return ramp_func
|
||||||
|
|
||||||
|
|
||||||
def get_mscale(scale=1):
|
def get_mscale(scale=1):
|
||||||
if scale <= 1:
|
if scale <= 1:
|
||||||
return 1.0
|
return 1.0
|
||||||
return 0.1 * math.log(scale) + 1.0
|
return 0.1 * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
|
||||||
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor,
|
def __init__(
|
||||||
attn_factor, beta_fast, beta_slow):
|
self,
|
||||||
|
dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
base,
|
||||||
|
device,
|
||||||
|
scaling_factor,
|
||||||
|
*,
|
||||||
|
extrapolation_factor,
|
||||||
|
attn_factor,
|
||||||
|
beta_fast,
|
||||||
|
beta_slow,
|
||||||
|
):
|
||||||
inv_freq = _create_inv_freq(dim, base, device)
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
super().__init__(inv_freq, scaling_factor)
|
super().__init__(inv_freq, scaling_factor)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
@ -906,16 +926,17 @@ try:
|
||||||
self.attn_factor = attn_factor
|
self.attn_factor = attn_factor
|
||||||
self.beta_fast = beta_fast
|
self.beta_fast = beta_fast
|
||||||
self.beta_slow = beta_slow
|
self.beta_slow = beta_slow
|
||||||
self.mscale = float(get_mscale(
|
self.mscale = float(
|
||||||
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
get_mscale(self.scaling_factor) * self.attn_factor
|
||||||
|
) # Get n-d magnitude scaling corrected for interpolation
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
# or if we're on a new device (possibly due to tracing for instance)
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
if (
|
if (
|
||||||
seqlen > self._seq_len_cached
|
seqlen > self._seq_len_cached
|
||||||
or self._cos_cached.device != device
|
or self._cos_cached.device != device
|
||||||
or self._cos_cached.dtype != dtype
|
or self._cos_cached.dtype != dtype
|
||||||
):
|
):
|
||||||
if seqlen > self.max_position_embeddings:
|
if seqlen > self.max_position_embeddings:
|
||||||
inv_freq_extrapolation = _create_inv_freq(
|
inv_freq_extrapolation = _create_inv_freq(
|
||||||
|
@ -923,15 +944,26 @@ try:
|
||||||
)
|
)
|
||||||
freqs = 1.0 / inv_freq_extrapolation
|
freqs = 1.0 / inv_freq_extrapolation
|
||||||
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
|
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
|
||||||
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base,
|
low, high = find_correction_range(
|
||||||
self.max_position_embeddings)
|
self.beta_fast,
|
||||||
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(
|
self.beta_slow,
|
||||||
device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
self.dim,
|
||||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
self.base,
|
||||||
|
self.max_position_embeddings,
|
||||||
|
)
|
||||||
|
inv_freq_mask = (
|
||||||
|
1
|
||||||
|
- linear_ramp_mask(low, high, self.dim // 2).float().to(device)
|
||||||
|
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
||||||
|
inv_freq = (
|
||||||
|
inv_freq_interpolation * (1 - inv_freq_mask)
|
||||||
|
+ inv_freq_extrapolation * inv_freq_mask
|
||||||
|
)
|
||||||
|
|
||||||
self.inv_freq = inv_freq
|
self.inv_freq = inv_freq
|
||||||
self.mscale = float(get_mscale(
|
self.mscale = float(
|
||||||
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
get_mscale(self.scaling_factor) * self.attn_factor
|
||||||
|
) # Get n-d magnitude scaling corrected for interpolation
|
||||||
|
|
||||||
self._seq_len_cached = seqlen
|
self._seq_len_cached = seqlen
|
||||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
|
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Output:
|
class Output:
|
||||||
logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
|
@ -11,7 +12,9 @@ class Output:
|
||||||
class ResBlock(torch.nn.Module):
|
class ResBlock(torch.nn.Module):
|
||||||
def __init__(self, config, prefix, weights):
|
def __init__(self, config, prefix, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
|
self.linear = FastLinear.load(
|
||||||
|
config, prefix=f"{prefix}.linear", weights=weights, bias=True
|
||||||
|
)
|
||||||
self.act = torch.nn.SiLU()
|
self.act = torch.nn.SiLU()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -19,15 +22,13 @@ class ResBlock(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class MedusaModel(torch.nn.Module):
|
class MedusaModel(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(self, config, weights, lm_head):
|
||||||
self,
|
|
||||||
config,
|
|
||||||
weights,
|
|
||||||
lm_head
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.heads = torch.nn.ModuleList(
|
self.heads = torch.nn.ModuleList(
|
||||||
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
|
[
|
||||||
|
MedusaHead(config, prefix=f"{i}", weights=weights)
|
||||||
|
for i in range(config["medusa_num_heads"])
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.lm_head = lm_head
|
self.lm_head = lm_head
|
||||||
|
|
||||||
|
@ -40,9 +41,16 @@ class MedusaModel(torch.nn.Module):
|
||||||
class MedusaHead(torch.nn.Module):
|
class MedusaHead(torch.nn.Module):
|
||||||
def __init__(self, config, prefix, weights):
|
def __init__(self, config, prefix, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
|
self.blocks = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
||||||
|
for i in range(config["medusa_num_layers"])
|
||||||
|
]
|
||||||
|
)
|
||||||
n = len(self.blocks)
|
n = len(self.blocks)
|
||||||
self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
|
self.out = FastLinear.load(
|
||||||
|
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
|
|
|
@ -7,23 +7,26 @@ from vllm import attention_ops
|
||||||
_PARTITION_SIZE = 512
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor,
|
def reshape_and_cache(
|
||||||
slots: torch.Tensor):
|
key: torch.Tensor,
|
||||||
cache_ops.reshape_and_cache(
|
value: torch.Tensor,
|
||||||
key, value, key_cache, value_cache, slots
|
key_cache: torch.Tensor,
|
||||||
)
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
out: torch.Tensor,
|
out: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
kv_head_mapping: torch.Tensor,
|
kv_head_mapping: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
):
|
):
|
||||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||||
# Copyright 2023 The vLLM team. All rights
|
# Copyright 2023 The vLLM team. All rights
|
||||||
|
@ -45,9 +48,7 @@ def attention(
|
||||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||||
block_size = value_cache.shape[3]
|
block_size = value_cache.shape[3]
|
||||||
num_seqs, num_heads, head_size = query.shape
|
num_seqs, num_heads, head_size = query.shape
|
||||||
max_num_partitions = (
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
(max_s + _PARTITION_SIZE - 1) //
|
|
||||||
_PARTITION_SIZE)
|
|
||||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
|
|
|
@ -38,7 +38,9 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
|
||||||
os.makedirs(model_id, exist_ok=True)
|
os.makedirs(model_id, exist_ok=True)
|
||||||
cache_dir = model_id
|
cache_dir = model_id
|
||||||
logger.info(f"Saving the newly created merged model to {cache_dir}")
|
logger.info(f"Saving the newly created merged model to {cache_dir}")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=trust_remote_code)
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
base_model_id, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
model.save_pretrained(cache_dir, safe_serialization=True)
|
model.save_pretrained(cache_dir, safe_serialization=True)
|
||||||
model.config.save_pretrained(cache_dir)
|
model.config.save_pretrained(cache_dir)
|
||||||
tokenizer.save_pretrained(cache_dir)
|
tokenizer.save_pretrained(cache_dir)
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
|
|
||||||
SPECULATE = None
|
SPECULATE = None
|
||||||
|
|
||||||
|
|
||||||
def get_speculate() -> int:
|
def get_speculate() -> int:
|
||||||
global SPECULATE
|
global SPECULATE
|
||||||
return SPECULATE
|
return SPECULATE
|
||||||
|
|
||||||
|
|
||||||
def set_speculate(speculate: int):
|
def set_speculate(speculate: int):
|
||||||
global SPECULATE
|
global SPECULATE
|
||||||
SPECULATE = speculate
|
SPECULATE = speculate
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ from text_generation_server.utils.logits_process import (
|
||||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||||
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
class NextTokenChooser:
|
class NextTokenChooser:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -145,21 +146,31 @@ class StoppingCriteria:
|
||||||
pb.ignore_eos_token,
|
pb.ignore_eos_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int, verbose: bool):
|
|
||||||
|
def create_n_gram_speculation(
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
next_ids: torch.Tensor,
|
||||||
|
accepted_ids: torch.Tensor,
|
||||||
|
speculate: int,
|
||||||
|
verbose: bool,
|
||||||
|
):
|
||||||
# Very trivial approach, find first match in the string.
|
# Very trivial approach, find first match in the string.
|
||||||
# This is much less refined than actual n-gram but seems to work
|
# This is much less refined than actual n-gram but seems to work
|
||||||
# relatively OK in grounded mode and is by far much faster with
|
# relatively OK in grounded mode and is by far much faster with
|
||||||
# much less worst case complexity as everything happens on device.
|
# much less worst case complexity as everything happens on device.
|
||||||
B = accepted_ids.shape[0]
|
B = accepted_ids.shape[0]
|
||||||
device = input_ids.device
|
device = input_ids.device
|
||||||
seeds = next_ids[accepted_ids.cumsum(dim=-1) -1 ]
|
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
|
||||||
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
|
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
|
||||||
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(speculate, device=device)
|
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
|
||||||
|
speculate, device=device
|
||||||
|
)
|
||||||
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
|
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
|
||||||
|
|
||||||
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
|
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
|
||||||
return speculative_ids
|
return speculative_ids
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousNextTokenChooser:
|
class HeterogeneousNextTokenChooser:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -228,7 +239,15 @@ class HeterogeneousNextTokenChooser:
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None, verbose=False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
scores: torch.Tensor,
|
||||||
|
speculate: int,
|
||||||
|
speculated_ids: Optional[torch.Tensor] = None,
|
||||||
|
speculative_scores: Optional[torch.Tensor] = None,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
if speculated_ids is not None:
|
if speculated_ids is not None:
|
||||||
B = scores.shape[0] // (speculated_ids.shape[1] + 1)
|
B = scores.shape[0] // (speculated_ids.shape[1] + 1)
|
||||||
S = speculated_ids.shape[1] + 1
|
S = speculated_ids.shape[1] + 1
|
||||||
|
@ -249,12 +268,11 @@ class HeterogeneousNextTokenChooser:
|
||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
_scores = warper(input_ids, _scores)
|
_scores = warper(input_ids, _scores)
|
||||||
|
|
||||||
|
|
||||||
_next_ids = self.choice(_scores)
|
_next_ids = self.choice(_scores)
|
||||||
scores[:, j] = _scores
|
scores[:, j] = _scores
|
||||||
next_ids[:, j] = _next_ids
|
next_ids[:, j] = _next_ids
|
||||||
next_ids = next_ids.view(B*S)
|
next_ids = next_ids.view(B * S)
|
||||||
scores = scores.view( B* S, -1)
|
scores = scores.view(B * S, -1)
|
||||||
|
|
||||||
if speculated_ids is not None:
|
if speculated_ids is not None:
|
||||||
accepted_ids = []
|
accepted_ids = []
|
||||||
|
@ -262,7 +280,7 @@ class HeterogeneousNextTokenChooser:
|
||||||
S = speculated_ids.shape[1] + 1
|
S = speculated_ids.shape[1] + 1
|
||||||
indices = []
|
indices = []
|
||||||
for i in range(B):
|
for i in range(B):
|
||||||
_next_ids = next_ids[i*S: (i + 1)*S]
|
_next_ids = next_ids[i * S : (i + 1) * S]
|
||||||
_speculated_ids = speculated_ids[i]
|
_speculated_ids = speculated_ids[i]
|
||||||
validate_speculative = _next_ids[:-1] == _speculated_ids
|
validate_speculative = _next_ids[:-1] == _speculated_ids
|
||||||
index = i * S
|
index = i * S
|
||||||
|
@ -278,7 +296,9 @@ class HeterogeneousNextTokenChooser:
|
||||||
break
|
break
|
||||||
accepted_ids.append(accepted)
|
accepted_ids.append(accepted)
|
||||||
|
|
||||||
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
|
accepted_ids = torch.tensor(
|
||||||
|
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
|
||||||
|
)
|
||||||
next_ids = next_ids[indices]
|
next_ids = next_ids[indices]
|
||||||
scores = scores[indices]
|
scores = scores[indices]
|
||||||
indices = torch.arange(B, device=input_ids.device) * S
|
indices = torch.arange(B, device=input_ids.device) * S
|
||||||
|
@ -296,7 +316,9 @@ class HeterogeneousNextTokenChooser:
|
||||||
speculative_ids = Greedy()(speculative_scores)
|
speculative_ids = Greedy()(speculative_scores)
|
||||||
else:
|
else:
|
||||||
# n-gram
|
# n-gram
|
||||||
speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate, verbose)
|
speculative_ids = create_n_gram_speculation(
|
||||||
|
input_ids, next_ids, accepted_ids, speculate, verbose
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
speculative_ids = None
|
speculative_ids = None
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ class Weights:
|
||||||
dtype,
|
dtype,
|
||||||
process_group,
|
process_group,
|
||||||
aliases: Optional[Dict[str, List[str]]] = None,
|
aliases: Optional[Dict[str, List[str]]] = None,
|
||||||
prefix: Optional[str] = None
|
prefix: Optional[str] = None,
|
||||||
):
|
):
|
||||||
routing = {}
|
routing = {}
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
|
@ -213,7 +213,8 @@ class Weights:
|
||||||
|
|
||||||
bits, groupsize = self._get_gptq_params()
|
bits, groupsize = self._get_gptq_params()
|
||||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||||
use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq"
|
|
||||||
|
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq"
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||||
else:
|
else:
|
||||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||||
|
@ -283,7 +284,7 @@ class Weights:
|
||||||
if use_exllama:
|
if use_exllama:
|
||||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim= 0)
|
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
g_idx = g_idx - g_idx[0]
|
g_idx = g_idx - g_idx[0]
|
||||||
else:
|
else:
|
||||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||||
|
|
|
@ -21,14 +21,14 @@ def main():
|
||||||
block = []
|
block = []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if line.startswith(" -") or line.startswith(" -"):
|
if line.startswith(" -") or line.startswith(" -"):
|
||||||
rendered_block = '\n'.join(block)
|
rendered_block = "\n".join(block)
|
||||||
if header:
|
if header:
|
||||||
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
|
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
|
||||||
else:
|
else:
|
||||||
final_doc += f"```shell\n{rendered_block}\n```\n"
|
final_doc += f"```shell\n{rendered_block}\n```\n"
|
||||||
block = []
|
block = []
|
||||||
tokens = line.split("<")
|
tokens = line.split("<")
|
||||||
if len(tokens)>1:
|
if len(tokens) > 1:
|
||||||
header = tokens[-1][:-1]
|
header = tokens[-1][:-1]
|
||||||
else:
|
else:
|
||||||
header = line.split("--")[-1]
|
header = line.split("--")[-1]
|
||||||
|
@ -36,7 +36,7 @@ def main():
|
||||||
|
|
||||||
block.append(line)
|
block.append(line)
|
||||||
|
|
||||||
rendered_block = '\n'.join(block)
|
rendered_block = "\n".join(block)
|
||||||
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
|
final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n"
|
||||||
block = []
|
block = []
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue