Revamp medusa implementation so that every model can benefit. (#1588)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2024-02-26 19:49:28 +01:00 committed by GitHub
parent ac5a1c6f51
commit bf700e7eef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 352 additions and 283 deletions

View File

@ -236,6 +236,7 @@ def launcher(event_loop):
use_flash_attention: bool = True, use_flash_attention: bool = True,
disable_grammar_support: bool = False, disable_grammar_support: bool = False,
dtype: Optional[str] = None, dtype: Optional[str] = None,
revision: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000) master_port = random.randint(10_000, 20_000)
@ -268,6 +269,9 @@ def launcher(event_loop):
if dtype is not None: if dtype is not None:
args.append("--dtype") args.append("--dtype")
args.append(dtype) args.append(dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
if trust_remote_code: if trust_remote_code:
args.append("--trust-remote-code") args.append("--trust-remote-code")
@ -302,6 +306,7 @@ def launcher(event_loop):
use_flash_attention: bool = True, use_flash_attention: bool = True,
disable_grammar_support: bool = False, disable_grammar_support: bool = False,
dtype: Optional[str] = None, dtype: Optional[str] = None,
revision: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
@ -317,6 +322,9 @@ def launcher(event_loop):
if dtype is not None: if dtype is not None:
args.append("--dtype") args.append("--dtype")
args.append(dtype) args.append(dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
if trust_remote_code: if trust_remote_code:
args.append("--trust-remote-code") args.append("--trust-remote-code")

View File

@ -3,7 +3,9 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_medusa_handle(launcher): def flash_medusa_handle(launcher):
with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle: with launcher(
"FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1"
) as handle:
yield handle yield handle

View File

@ -154,12 +154,8 @@ def download_weights(
import json import json
medusa_head = hf_hub_download( medusa_head = hf_hub_download(
model_id, revision=revision, filename="medusa_lm_head.pt" model_id, revision=revision, filename="medusa_lm_head.safetensors"
) )
if auto_convert:
medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = hf_hub_download( medusa_config = hf_hub_download(
model_id, revision=revision, filename="config.json" model_id, revision=revision, filename="config.json"
) )
@ -198,16 +194,12 @@ def download_weights(
if not extension == ".safetensors" or not auto_convert: if not extension == ".safetensors" or not auto_convert:
raise e raise e
elif (Path(model_id) / "medusa_lm_head.pt").exists(): elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
# Try to load as a local Medusa model # Try to load as a local Medusa model
try: try:
import json import json
medusa_head = Path(model_id) / "medusa_lm_head.pt" medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
if auto_convert:
medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = Path(model_id) / "config.json" medusa_config = Path(model_id) / "config.json"
with open(medusa_config, "r") as f: with open(medusa_config, "r") as f:
config = json.load(f) config = json.load(f)

View File

@ -3,7 +3,9 @@ import torch
from loguru import logger from loguru import logger
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional from typing import Optional
from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
@ -115,44 +117,14 @@ def get_model(
else: else:
set_speculate(0) set_speculate(0)
if "facebook/galactica" in model_id:
return GalacticaSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_id.startswith("bigcode/"):
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return SantaCoder(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
use_medusa = None use_medusa = None
if "medusa_num_heads" in config_dict: if "medusa_num_heads" in config_dict:
use_medusa = model_id medusa_model_id = model_id
medusa_revision = revision
model_id = config_dict["base_model_name_or_path"] model_id = config_dict["base_model_name_or_path"]
revision = "main" revision = "main"
speculate_medusa = config_dict["medusa_num_heads"] speculate_medusa = config_dict["medusa_num_heads"]
@ -169,6 +141,20 @@ def get_model(
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
is_local = Path(medusa_model_id).exists()
if not is_local:
medusa_config = hf_hub_download(
medusa_model_id, revision=medusa_revision, filename="config.json"
)
hf_hub_download(
medusa_model_id,
revision=medusa_revision,
filename="medusa_lm_head.safetensors",
)
use_medusa = Path(medusa_config).parent
else:
use_medusa = Path(medusa_model_id)
method = "medusa" method = "medusa"
else: else:
method = "n-gram" method = "n-gram"
@ -193,16 +179,22 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == "gpt_bigcode": if (
model_type == "gpt_bigcode"
or model_type == "gpt2"
and model_id.startswith("bigcode/")
):
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashSantacoderSharded( return FlashSantacoderSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -215,6 +207,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -224,6 +217,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -232,6 +226,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -242,6 +237,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -250,6 +246,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -258,6 +255,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -268,15 +266,16 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
) )
else: else:
return CausalLM( return CausalLM(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -291,6 +290,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -301,9 +301,9 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@ -312,6 +312,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -321,9 +322,9 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
) )
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
@ -334,6 +335,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -347,6 +349,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -357,6 +360,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -365,6 +369,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -378,6 +383,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -391,6 +397,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -400,6 +407,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -409,6 +417,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -418,6 +427,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -441,6 +451,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -449,6 +460,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -460,6 +472,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -468,6 +481,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )

View File

@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM):
) )
config.pad_token_id = 3 config.pad_token_id = 3
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@ -103,7 +105,7 @@ class BLOOMSharded(CausalLM):
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):
outputs = self.model.forward( outputs, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
@ -112,4 +114,4 @@ class BLOOMSharded(CausalLM):
) )
logits = outputs.logits logits = outputs.logits
return logits, outputs.past_key_values return logits, speculative_logits, outputs.past_key_values

View File

@ -482,6 +482,7 @@ class CausalLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -550,7 +551,9 @@ class CausalLM(Model):
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[
torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
]:
# Model Forward # Model Forward
kwargs = { kwargs = {
"input_ids": input_ids, "input_ids": input_ids,
@ -563,7 +566,11 @@ class CausalLM(Model):
kwargs["position_ids"] = position_ids kwargs["position_ids"] = position_ids
outputs = self.model.forward(**kwargs) outputs = self.model.forward(**kwargs)
return outputs.logits, outputs.past_key_values if isinstance(outputs, tuple):
outputs, speculative_logits = outputs
else:
speculative_logits = None
return outputs.logits, speculative_logits, outputs.past_key_values
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
def generate_token( def generate_token(
@ -573,7 +580,7 @@ class CausalLM(Model):
# slice the attention mask to the correct shape # slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
logits, past = self.forward( logits, speculative_logits, past = self.forward(
batch.input_ids, batch.input_ids,
attention_mask, attention_mask,
batch.position_ids, batch.position_ids,

View File

@ -36,7 +36,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
) )
CUSTOM_KERNELS_ENABLED = False CUSTOM_KERNELS_ENABLED = False
@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
super().__init__(config) super().__init__(config)
self.transformer = BloomModel(config, weights) self.transformer = BloomModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="word_embeddings", prefix="word_embeddings",
weights=weights, weights=weights,
@ -904,17 +904,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
loss = None loss = None
if not return_dict: if not return_dict:
output = (lm_logits,) + transformer_outputs[1:] output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions( return (
loss=loss, CausalLMOutputWithCrossAttentions(
logits=lm_logits, loss=loss,
past_key_values=transformer_outputs.past_key_values, logits=logits,
hidden_states=transformer_outputs.hidden_states, past_key_values=transformer_outputs.past_key_values,
attentions=transformer_outputs.attentions, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
),
speculative_logits,
) )

View File

@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastRMSNorm, FastRMSNorm,
) )
@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = FlashGemmaModel(config, weights) self.model = FlashGemmaModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
weights=weights, weights=weights,
@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
return logits return logits, speculative_logits

View File

@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastRMSNorm, FastRMSNorm,
) )
@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = FlashLlamaModel(config, weights) self.model = FlashLlamaModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,
@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
return logits return logits, speculative_logits

View File

@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastRMSNorm, FastRMSNorm,
) )
@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = MistralModel(config, weights) self.model = MistralModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,

View File

@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
) )
@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = MixtralModel(config, weights) self.model = MixtralModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,

View File

@ -33,7 +33,7 @@ from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead, SpeculativeHead,
FastLayerNorm, FastLayerNorm,
PositionRotaryEmbedding, PositionRotaryEmbedding,
get_linear, get_linear,
@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
super().__init__(config) super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights) self.gpt_neox = FlashGPTNeoXModel(config, weights)
self.embed_out = TensorParallelHead.load( self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights config, prefix="embed_out", weights=weights
) )

View File

@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastLayerNorm, FastLayerNorm,
) )
@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = FlashPhiModel(config, weights) self.model = FlashPhiModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,

View File

@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead, SpeculativeHead,
FastLayerNorm, FastLayerNorm,
PositionRotaryEmbedding, PositionRotaryEmbedding,
get_linear, get_linear,
@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
self.transformer = FlashRWModel(config, weights) self.transformer = FlashRWModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
config, prefix="lm_head", weights=weights
)
def forward( def forward(
self, self,

View File

@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelHead, SpeculativeHead,
TensorParallelEmbedding, TensorParallelEmbedding,
FastLayerNorm, FastLayerNorm,
get_linear, get_linear,
@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, weights): def __init__(self, config, weights):
super().__init__() super().__init__()
self.transformer = FlashSantacoderModel(config, weights) self.transformer = FlashSantacoderModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix="transformer.wte", weights=weights
) )

View File

@ -51,7 +51,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
PositionRotaryEmbedding, PositionRotaryEmbedding,
FastLinear, FastLinear,
) )
@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
weights, weights,
) -> None: ) -> None:
super().__init__() super().__init__()
self.fc = TensorParallelHead.load( self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
config=config, prefix="lm_head", weights=weights
)
self.additional_fc = FastLinear.load( self.additional_fc = FastLinear.load(
config=config, config=config,
prefix="lm_head.additional_fc", prefix="lm_head.additional_fc",
@ -283,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
) )
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
output = self.fc(input) output, speculative_logits = self.fc(input)
additional_features = self.additional_fc(input) additional_features = self.additional_fc(input)
output = torch.cat((output, additional_features), -1) output = torch.cat((output, additional_features), -1)
return output return output, speculative_logits
def extra_repr(self) -> str: def extra_repr(self) -> str:
"""Overwriting `nn.Linear.extra_repr` to include new parameters.""" """Overwriting `nn.Linear.extra_repr` to include new parameters."""
@ -1503,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
) )
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
loss = None loss = None
return CausalLMOutputWithPastImage( return (
loss=loss, CausalLMOutputWithPastImage(
logits=logits, loss=loss,
past_key_values=outputs.past_key_values, logits=logits,
hidden_states=outputs.hidden_states, past_key_values=outputs.past_key_values,
attentions=outputs.attentions, hidden_states=outputs.hidden_states,
image_hidden_states=outputs.image_hidden_states, attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
),
speculative_logits,
) )
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):

View File

@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F import torch.nn.functional as F
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
SpeculativeHead,
TensorParallelEmbedding, TensorParallelEmbedding,
FastRMSNorm, FastRMSNorm,
FastLinear, FastLinear,
@ -205,14 +206,12 @@ class MambaModel(nn.Module):
self.norm_f = FastRMSNorm.load( self.norm_f = FastRMSNorm.load(
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
) )
self.lm_head = FastLinear.load( self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
config, f"{prefix}.embedding", weights, bias=False
)
self.config = config self.config = config
def forward( def forward(
self, input_ids: torch.Tensor, inference_params=None, residual=None self, input_ids: torch.Tensor, inference_params=None, residual=None
) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
hidden_states, residual, conv_state, ssm_state = block( hidden_states, residual, conv_state, ssm_state = block(
@ -226,8 +225,8 @@ class MambaModel(nn.Module):
) )
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
hidden_states = hidden_states.view(residual.shape) hidden_states = hidden_states.view(residual.shape)
logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
# update the offset for the next inference using these params # update the offset for the next inference using these params
inference_params.seqlen_offset += input_ids.size(1) inference_params.seqlen_offset += input_ids.size(1)
return logits return logits, speculative_logits

View File

@ -21,7 +21,7 @@ from text_generation_server.utils.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
) )
@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
if not config.tie_word_embeddings: if not config.tie_word_embeddings:
raise ValueError("MPTForCausalLM only supports tied word embeddings") raise ValueError("MPTForCausalLM only supports tied word embeddings")
self.transformer = MPTModel(config, weights) self.transformer = MPTModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix="transformer.wte", weights=weights
) )
self.logit_scale = None self.logit_scale = None
@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
use_cache=use_cache, use_cache=use_cache,
) )
logits = self.lm_head(outputs.last_hidden_state) logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
if self.logit_scale is not None: if self.logit_scale is not None:
if self.logit_scale == 0: if self.logit_scale == 0:
warnings.warn( warnings.warn(
@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
loss = F.cross_entropy( loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
) )
return CausalLMOutputWithPast( return (
loss=loss, CausalLMOutputWithPast(
logits=logits, loss=loss,
past_key_values=outputs.past_key_values, logits=logits,
hidden_states=outputs.hidden_states, past_key_values=outputs.past_key_values,
attentions=outputs.attentions, hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
),
speculative_logits,
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(

View File

@ -44,7 +44,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
) )
@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
def __init__(self, config, weights): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights) self.gpt_neox = GPTNeoXModel(config, weights)
self.embed_out = TensorParallelHead.load( self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights config, prefix="embed_out", weights=weights
) )

View File

@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
) )
EPS = 1e-5 EPS = 1e-5
@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
self.model = OPTModel(config, weights) self.model = OPTModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights config, prefix="model.decoder.embed_tokens", weights=weights
) )

View File

@ -13,7 +13,7 @@ from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead, SpeculativeHead,
FastLinear, FastLinear,
) )
@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module):
weights=weights, weights=weights,
eps=config.layer_norm_epsilon, eps=config.layer_norm_epsilon,
) )
self.linear = TensorParallelHead.load( self.linear = SpeculativeHead.load(
config=config, prefix="lm_head.linear", weights=weights config=config, prefix="lm_head.linear", weights=weights
) )

View File

@ -42,7 +42,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
) )
@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
) )
try: try:
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="lm_head", weights=weights config, prefix="lm_head", weights=weights
) )
except RuntimeError: except RuntimeError:
# Some models like t5-small were saved with shared weights unlike flan # Some models like t5-small were saved with shared weights unlike flan
# Since they are declared as the same arch we have no choice but hope # Since they are declared as the same arch we have no choice but hope
# that this is OK instead of using a proper flag. # that this is OK instead of using a proper flag.
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="shared", weights=weights config, prefix="shared", weights=weights
) )
@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5) sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output) logits, speculative_logits = self.lm_head(sequence_output)
loss = None loss = None
if labels is not None: if labels is not None:
@ -1140,16 +1140,19 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput( return (
loss=loss, Seq2SeqLMOutput(
logits=lm_logits, loss=loss,
past_key_values=decoder_outputs.past_key_values, logits=logits,
decoder_hidden_states=decoder_outputs.hidden_states, past_key_values=decoder_outputs.past_key_values,
decoder_attentions=decoder_outputs.attentions, decoder_hidden_states=decoder_outputs.hidden_states,
cross_attentions=decoder_outputs.cross_attentions, decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state, cross_attentions=decoder_outputs.cross_attentions,
encoder_hidden_states=encoder_outputs.hidden_states, encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_attentions=encoder_outputs.attentions, encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
),
speculative_logits,
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(

View File

@ -723,7 +723,7 @@ class FlashCausalLM(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
@ -734,6 +734,8 @@ class FlashCausalLM(Model):
max_s=max_s, max_s=max_s,
lm_head_indices=None, lm_head_indices=None,
) )
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize() torch.cuda.synchronize()
def warmup(self, batch: FlashCausalLMBatch): def warmup(self, batch: FlashCausalLMBatch):
@ -805,7 +807,9 @@ class FlashCausalLM(Model):
return int(num_blocks * BLOCK_SIZE) return int(num_blocks * BLOCK_SIZE)
def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor: def forward(
self, batch: FlashCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
if batch.speculative_ids is not None: if batch.speculative_ids is not None:
input_ids = batch.input_ids input_ids = batch.input_ids
@ -900,9 +904,14 @@ class FlashCausalLM(Model):
# Replay the graph # Replay the graph
cuda_graph["graph"].replay() cuda_graph["graph"].replay()
# Slice output to the correct shape # Slice output to the correct shape
return cuda_graph["logits"][:bs] speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
def generate_token( def generate_token(
@ -926,16 +935,11 @@ class FlashCausalLM(Model):
batch.slots = slots batch.slots = slots
try: try:
out = self.forward(batch) out, speculative_logits = self.forward(batch)
except Exception as e: except Exception as e:
del batch del batch
raise e raise e
if isinstance(out, tuple):
out, speculative_logits = out
else:
speculative_logits = None
if prefill: if prefill:
next_token_logits = ( next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out out[batch.prefill_next_token_indices] if prefill_logprobs else out

View File

@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -59,36 +60,6 @@ class FlashGemma(FlashCausalLM):
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)
model = FlashGemmaForCausalLM(config, weights) model = FlashGemmaForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__( super(FlashGemma, self).__init__(

View File

@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -67,37 +68,6 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)
model = FlashLlamaForCausalLM(config, weights) model = FlashLlamaForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__( super(FlashLlama, self).__init__(
model=model, model=model,

View File

@ -294,6 +294,7 @@ class BaseFlashMistral(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -319,6 +320,7 @@ class BaseFlashMistral(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
# Set context windows # Set context windows
if config.sliding_window is not None: if config.sliding_window is not None:
@ -394,7 +396,7 @@ class BaseFlashMistral(FlashCausalLM):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
@ -406,9 +408,13 @@ class BaseFlashMistral(FlashCausalLM):
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
) )
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize() torch.cuda.synchronize()
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(
self, batch: FlashMistralBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
if batch.speculative_ids is not None: if batch.speculative_ids is not None:
input_ids = batch.input_ids input_ids = batch.input_ids
@ -479,7 +485,7 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph = self.cuda_graphs.get(padded_bs, None) cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
@ -493,7 +499,7 @@ class BaseFlashMistral(FlashCausalLM):
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
return logits return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded # Static inputs are potentially padded
@ -511,7 +517,13 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph["graph"].replay() cuda_graph["graph"].replay()
# Slice output to the correct shape # Slice output to the correct shape
return cuda_graph["logits"][:bs] speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
class FlashMistral(BaseFlashMistral): class FlashMistral(BaseFlashMistral):
@ -520,6 +532,7 @@ class FlashMistral(BaseFlashMistral):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -529,6 +542,7 @@ class FlashMistral(BaseFlashMistral):
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )

View File

@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral):
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )

View File

@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM):
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
if config.quantize == "gptq": if config.quantize == "gptq":
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)

View File

@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM):
trust_remote_code=True, trust_remote_code=True,
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
config.transpose = config.architectures[0].startswith("GPT2") config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
config.vision_config.quantize = quantize config.vision_config.quantize = quantize
tokenizer = LlamaTokenizerFast.from_pretrained( tokenizer = LlamaTokenizerFast.from_pretrained(

View File

@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
if self.has_position_ids: if self.has_position_ids:
kwargs["position_ids"] = position_ids kwargs["position_ids"] = position_ids
outputs = self.model.forward(**kwargs) outputs, speculative_logits = self.model.forward(**kwargs)
return outputs.logits, outputs.past_key_values, outputs.image_hidden_states return (
outputs.logits,
speculative_logits,
outputs.past_key_values,
outputs.image_hidden_states,
)
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
def generate_token( def generate_token(
@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
:, : -batch.padding_right_offset :, : -batch.padding_right_offset
] ]
logits, past, image_hidden_states = self.forward( logits, speculative_logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids, input_ids=batch.input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=batch.position_ids, position_ids=batch.position_ids,

View File

@ -408,6 +408,7 @@ class Mamba(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -444,6 +445,7 @@ class Mamba(Model):
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group) weights = Weights(filenames, device, dtype, process_group=self.process_group)
@ -505,7 +507,7 @@ class Mamba(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, inference_params=inference_params input_ids=input_ids, inference_params=inference_params
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@ -514,6 +516,7 @@ class Mamba(Model):
"inference_params": inference_params, "inference_params": inference_params,
"graph": graph, "graph": graph,
"logits": logits, "logits": logits,
"speculative_logits": speculative_logits,
} }
self.cuda_graphs[batch_size] = graph_dict self.cuda_graphs[batch_size] = graph_dict
@ -556,9 +559,14 @@ class Mamba(Model):
inference_params.ssm_states.copy_( inference_params.ssm_states.copy_(
cuda_graph["inference_params"].ssm_states[:, :bs] cuda_graph["inference_params"].ssm_states[:, :bs]
) )
# Slice output to the correct shape # Slice output to the correct shape
return cuda_graph["logits"][:bs] speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns() start = time.time_ns()
@ -589,7 +597,9 @@ class Mamba(Model):
batch.inference_params = inference_params batch.inference_params = inference_params
# Forward pass # Forward pass
logits = self.forward(input_ids, inference_params=batch.inference_params) logits, speculative_logits = self.forward(
input_ids, inference_params=batch.inference_params
)
# batch.inference_params = new_inference_params # batch.inference_params = new_inference_params
# Results # Results

View File

@ -43,6 +43,7 @@ class MPTSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -75,6 +76,7 @@ class MPTSharded(CausalLM):
config = json.load(f) config = json.load(f)
config = PretrainedConfig(**config) config = PretrainedConfig(**config)
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -22,6 +22,7 @@ class OPTSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -47,6 +48,7 @@ class OPTSharded(CausalLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
tokenizer.pad_token_id = config.pad_token_id tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -22,6 +22,7 @@ class Phi(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -52,6 +53,7 @@ class Phi(CausalLM):
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group) weights = Weights(filenames, device, dtype, process_group=self.process_group)

View File

@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):

View File

@ -532,6 +532,7 @@ class Seq2SeqLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -596,6 +597,7 @@ class Seq2SeqLM(Model):
past_key_values: Optional = None, past_key_values: Optional = None,
) -> Tuple[ ) -> Tuple[
torch.Tensor, torch.Tensor,
Optional[torch.Tensor],
torch.Tensor, torch.Tensor,
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]: ]:
@ -609,8 +611,15 @@ class Seq2SeqLM(Model):
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )
if isinstance(outputs, tuple):
# Our custom models
outputs, speculative_logits = outputs
else:
# Generic transformers models
speculative_logits = None
return ( return (
outputs.logits, outputs.logits,
speculative_logits,
outputs.encoder_last_hidden_state, outputs.encoder_last_hidden_state,
outputs.past_key_values, outputs.past_key_values,
) )
@ -635,7 +644,7 @@ class Seq2SeqLM(Model):
else: else:
encoder_last_hidden_state = None encoder_last_hidden_state = None
logits, encoder_last_hidden_state, past = self.forward( logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids, batch.input_ids,
batch.attention_mask, batch.attention_mask,
batch.decoder_input_ids, batch.decoder_input_ids,

View File

@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,
@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM):
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]: ]:
# Model Forward # Model Forward
outputs = self.model.forward( outputs, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM):
return ( return (
outputs.logits, outputs.logits,
speculative_logits,
outputs.encoder_last_hidden_state, outputs.encoder_last_hidden_state,
outputs.past_key_values, outputs.past_key_values,
) )

View File

@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info(
and "arguments" not in s.rfilename and "arguments" not in s.rfilename
and "args" not in s.rfilename and "args" not in s.rfilename
and "training" not in s.rfilename and "training" not in s.rfilename
and "medusa_lm_head" not in s.rfilename
] ]
@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
and "args" not in f and "args" not in f
and "adapter" not in f and "adapter" not in f
and "training" not in f and "training" not in f
and "medusa_lm_head" not in f
] ]
return filenames return filenames

View File

@ -4,7 +4,7 @@ import torch.distributed
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import List from typing import List, Tuple, Optional
from loguru import logger from loguru import logger
from functools import lru_cache from functools import lru_cache
@ -380,6 +380,96 @@ class SuperLayer(nn.Module):
return self.linear.forward(x) return self.linear.forward(x)
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
)
def forward(self, x):
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x
class SpeculativeHead(nn.Module):
def __init__(self, lm_head, medusa):
super().__init__()
self.lm_head = lm_head
self.medusa = medusa
@staticmethod
def load(config, prefix: str, weights):
lm_head = TensorParallelHead.load(config, prefix, weights)
use_medusa = config.use_medusa
if use_medusa:
from pathlib import Path
from safetensors import safe_open
import json
medusa_config = str(Path(use_medusa) / "config.json")
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
with open(medusa_config, "r") as f:
config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
weights.routing[k] = filename
medusa = MedusaModel(config, weights)
else:
medusa = None
return SpeculativeHead(lm_head, medusa)
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
speculative_logits = self.medusa(input) if self.medusa is not None else None
return logits, speculative_logits
class TensorParallelHead(SuperLayer): class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group, should_gather: bool): def __init__(self, linear, process_group, should_gather: bool):
super().__init__(linear) super().__init__(linear)

View File

@ -1,59 +0,0 @@
import torch
from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass
class Output:
logits: torch.FloatTensor = None
speculative_logits: torch.FloatTensor = None
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(self, config, weights, lm_head):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
)
self.lm_head = lm_head
def forward(self, x):
logits = self.lm_head(x)
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return logits, speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x