Revamp medusa implementation so that every model can benefit. (#1588)
# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Parent: ac5a1c6f51
Commit: bf700e7eef
@@ -236,6 +236,7 @@ def launcher(event_loop):
         use_flash_attention: bool = True,
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
+        revision: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
@@ -268,6 +269,9 @@ def launcher(event_loop):
         if dtype is not None:
             args.append("--dtype")
             args.append(dtype)
+        if revision is not None:
+            args.append("--revision")
+            args.append(revision)
         if trust_remote_code:
             args.append("--trust-remote-code")

@@ -302,6 +306,7 @@ def launcher(event_loop):
         use_flash_attention: bool = True,
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
+        revision: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)

@@ -317,6 +322,9 @@ def launcher(event_loop):
         if dtype is not None:
             args.append("--dtype")
             args.append(dtype)
+        if revision is not None:
+            args.append("--revision")
+            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")

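For reference, a minimal sketch (names are illustrative, not the fixture itself) of how the new optional `revision` value flows into `text-generation-launcher` CLI arguments, mirroring the `args.append` pattern above:

```python
from typing import List, Optional


def build_launcher_args(
    model_id: str,
    dtype: Optional[str] = None,
    revision: Optional[str] = None,
    trust_remote_code: bool = False,
) -> List[str]:
    # Every optional value becomes a flag/value pair, as in the fixture.
    args = ["text-generation-launcher", "--model-id", model_id, "--port", "8080"]
    if dtype is not None:
        args.extend(["--dtype", dtype])
    if revision is not None:
        # New in this PR: pin a specific model revision (branch, tag or PR ref).
        args.extend(["--revision", revision])
    if trust_remote_code:
        args.append("--trust-remote-code")
    return args


print(build_launcher_args("FasterDecoding/medusa-vicuna-7b-v1.3", revision="refs/pr/1"))
```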
@@ -3,7 +3,9 @@ import pytest

 @pytest.fixture(scope="module")
 def flash_medusa_handle(launcher):
-    with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
+    with launcher(
+        "FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1"
+    ) as handle:
         yield handle

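The test now pins `revision="refs/pr/1"`, the Hub ref of pull request #1 on that repo, which is where the `.safetensors` Medusa head is expected to live. A small hedged sketch of inspecting such a revision with `huggingface_hub`:

```python
from huggingface_hub import list_repo_files

# "refs/pr/1" addresses pull request #1 of the repo on the Hub; listing its
# files shows whether the safetensors Medusa head is present there.
files = list_repo_files("FasterDecoding/medusa-vicuna-7b-v1.3", revision="refs/pr/1")
print([f for f in files if f.endswith(".safetensors")])
```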
@@ -154,12 +154,8 @@ def download_weights(
             import json

             medusa_head = hf_hub_download(
-                model_id, revision=revision, filename="medusa_lm_head.pt"
+                model_id, revision=revision, filename="medusa_lm_head.safetensors"
             )
-            if auto_convert:
-                medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
-                if not medusa_sf.exists():
-                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
             medusa_config = hf_hub_download(
                 model_id, revision=revision, filename="config.json"
             )
@@ -198,16 +194,12 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e

-    elif (Path(model_id) / "medusa_lm_head.pt").exists():
+    elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
        # Try to load as a local Medusa model
        try:
            import json

-            medusa_head = Path(model_id) / "medusa_lm_head.pt"
-            if auto_convert:
-                medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
-                if not medusa_sf.exists():
-                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+            medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
            medusa_config = Path(model_id) / "config.json"
            with open(medusa_config, "r") as f:
                config = json.load(f)
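With the automatic `.pt` to `.safetensors` conversion removed from `download_weights`, a Medusa head that only exists as `medusa_lm_head.pt` now has to be converted once up front. A hedged, one-off conversion sketch (paths are placeholders; it assumes a recent PyTorch and that the checkpoint is a plain tensor state dict):

```python
from pathlib import Path

import torch
from safetensors.torch import save_file

src = Path("medusa_lm_head.pt")          # placeholder path
dst = src.with_suffix(".safetensors")

# Load on CPU; weights_only avoids executing arbitrary pickled code.
state_dict = torch.load(src, map_location="cpu", weights_only=True)
# safetensors requires contiguous tensors.
state_dict = {k: v.contiguous() for k, v in state_dict.items()}
save_file(state_dict, str(dst))
print(f"wrote {dst}")
```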
@@ -3,7 +3,9 @@ import torch
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
+from huggingface_hub import hf_hub_download
 from typing import Optional
+from pathlib import Path

 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@@ -115,44 +117,14 @@ def get_model(
     else:
         set_speculate(0)

-    if "facebook/galactica" in model_id:
-        return GalacticaSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if model_id.startswith("bigcode/"):
-        if FLASH_ATTENTION:
-            return FlashSantacoderSharded(
-                model_id,
-                revision,
-                quantize=quantize,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
-        elif sharded:
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
-            )
-        else:
-            return SantaCoder(
-                model_id,
-                revision,
-                quantize=quantize,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
-
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )

     use_medusa = None
     if "medusa_num_heads" in config_dict:
-        use_medusa = model_id
+        medusa_model_id = model_id
+        medusa_revision = revision
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
         speculate_medusa = config_dict["medusa_num_heads"]
@@ -169,6 +141,20 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        is_local = Path(medusa_model_id).exists()
+        if not is_local:
+            medusa_config = hf_hub_download(
+                medusa_model_id, revision=medusa_revision, filename="config.json"
+            )
+            hf_hub_download(
+                medusa_model_id,
+                revision=medusa_revision,
+                filename="medusa_lm_head.safetensors",
+            )
+            use_medusa = Path(medusa_config).parent
+        else:
+            use_medusa = Path(medusa_model_id)
+
         method = "medusa"
     else:
         method = "n-gram"
@@ -193,16 +179,22 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

-    if model_type == "gpt_bigcode":
+    if (
+        model_type == "gpt_bigcode"
+        or model_type == "gpt2"
+        and model_id.startswith("bigcode/")
+    ):
         if FLASH_ATTENTION:
             return FlashSantacoderSharded(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -215,6 +207,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -224,6 +217,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -232,6 +226,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -242,6 +237,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -250,6 +246,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -258,6 +255,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -268,15 +266,16 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         else:
             return CausalLM(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -291,6 +290,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -301,9 +301,9 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -312,6 +312,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -321,9 +322,9 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(
@@ -334,6 +335,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -347,6 +349,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -357,6 +360,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -365,6 +369,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -378,6 +383,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -391,6 +397,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -400,6 +407,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -409,6 +417,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -418,6 +427,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -441,6 +451,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -449,6 +460,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -460,6 +472,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -468,6 +481,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
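A minimal sketch of the resolution logic `get_model` now applies when a config contains `medusa_num_heads`: the Medusa repo is downloaded (or used in place if it is a local path), and `use_medusa` becomes the directory holding `config.json` and `medusa_lm_head.safetensors`. The repo id and revision below are placeholders:

```python
from pathlib import Path
from typing import Optional

from huggingface_hub import hf_hub_download


def resolve_medusa(medusa_model_id: str, medusa_revision: Optional[str] = None) -> Path:
    """Return the local directory containing the Medusa config and head weights."""
    if Path(medusa_model_id).exists():
        # Local checkout: use it as-is.
        return Path(medusa_model_id)
    config_path = hf_hub_download(
        medusa_model_id, revision=medusa_revision, filename="config.json"
    )
    hf_hub_download(
        medusa_model_id,
        revision=medusa_revision,
        filename="medusa_lm_head.safetensors",
    )
    # Both files land in the same snapshot folder of the local HF cache.
    return Path(config_path).parent


# use_medusa = resolve_medusa("FasterDecoding/medusa-vicuna-7b-v1.3", "refs/pr/1")
```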
@@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -103,7 +105,7 @@ class BLOOMSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -112,4 +114,4 @@ class BLOOMSharded(CausalLM):
         )

         logits = outputs.logits
-        return logits, outputs.past_key_values
+        return logits, speculative_logits, outputs.past_key_values
@@ -482,6 +482,7 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -550,7 +551,9 @@ class CausalLM(Model):

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[
+        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
+    ]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
@@ -563,7 +566,11 @@ class CausalLM(Model):
             kwargs["position_ids"] = position_ids

         outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        if isinstance(outputs, tuple):
+            outputs, speculative_logits = outputs
+        else:
+            speculative_logits = None
+        return outputs.logits, speculative_logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -573,7 +580,7 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

-        logits, past = self.forward(
+        logits, speculative_logits, past = self.forward(
             batch.input_ids,
             attention_mask,
             batch.position_ids,
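The `isinstance(outputs, tuple)` check keeps `CausalLM.forward` compatible with stock `transformers` models, which still return a single output object. A small self-contained sketch of that unpacking pattern (the dummy output class is illustrative):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class DummyOutput:
    logits: torch.Tensor
    past_key_values: Optional[tuple] = None


def unpack(outputs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple]]:
    # Custom TGI models return (output, speculative_logits); generic
    # transformers models return the output object alone.
    if isinstance(outputs, tuple):
        outputs, speculative_logits = outputs
    else:
        speculative_logits = None
    return outputs.logits, speculative_logits, outputs.past_key_values


plain = DummyOutput(logits=torch.zeros(1, 4, 32000))
with_medusa = (plain, torch.zeros(1, 4, 3, 32000))
assert unpack(plain)[1] is None
assert unpack(with_medusa)[1] is not None
```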
@@ -36,7 +36,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )

 CUSTOM_KERNELS_ENABLED = False
@@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         super().__init__(config)
         self.transformer = BloomModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="word_embeddings",
             weights=weights,
@@ -904,17 +904,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]

-        lm_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
         loss = None

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

-        return CausalLMOutputWithCrossAttentions(
+        return (
+            CausalLMOutputWithCrossAttentions(
                 loss=loss,
-                logits=lm_logits,
+                logits=logits,
                 past_key_values=transformer_outputs.past_key_values,
                 hidden_states=transformer_outputs.hidden_states,
                 attentions=transformer_outputs.attentions,
+            ),
+            speculative_logits,
         )
@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashGemmaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
             weights=weights,
@@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashLlamaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = MistralModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )

@@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = MixtralModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -33,7 +33,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,
@@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         super().__init__(config)
         self.gpt_neox = FlashGPTNeoXModel(config, weights)

-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )

@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastLayerNorm,
 )
@@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashPhiModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,
@@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):

         self.transformer = FlashRWModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
-            config, prefix="lm_head", weights=weights
-        )
+        self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

     def forward(
         self,
@@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastLayerNorm,
     get_linear,
@@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         self.transformer = FlashSantacoderModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )

@@ -51,7 +51,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     PositionRotaryEmbedding,
     FastLinear,
 )
@@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = TensorParallelHead.load(
-            config=config, prefix="lm_head", weights=weights
-        )
+        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
         self.additional_fc = FastLinear.load(
             config=config,
             prefix="lm_head.additional_fc",
@@ -283,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        output = self.fc(input)
+        output, speculative_logits = self.fc(input)
         additional_features = self.additional_fc(input)
         output = torch.cat((output, additional_features), -1)

-        return output
+        return output, speculative_logits

     def extra_repr(self) -> str:
         """Overwriting `nn.Linear.extra_repr` to include new parameters."""
@@ -1503,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         )

         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)

         loss = None

-        return CausalLMOutputWithPastImage(
+        return (
+            CausalLMOutputWithPastImage(
                 loss=loss,
                 logits=logits,
                 past_key_values=outputs.past_key_values,
                 hidden_states=outputs.hidden_states,
                 attentions=outputs.attentions,
                 image_hidden_states=outputs.image_hidden_states,
+            ),
+            speculative_logits,
         )

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F

 from text_generation_server.utils.layers import (
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastRMSNorm,
     FastLinear,
@@ -205,14 +206,12 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = FastLinear.load(
-            config, f"{prefix}.embedding", weights, bias=False
-        )
+        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
         self.config = config

     def forward(
         self, input_ids: torch.Tensor, inference_params=None, residual=None
-    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.embed_tokens(input_ids)
         for i, block in enumerate(self.blocks):
             hidden_states, residual, conv_state, ssm_state = block(
@@ -226,8 +225,8 @@ class MambaModel(nn.Module):
         )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)

         # update the offset for the next inference using these params
         inference_params.seqlen_offset += input_ids.size(1)
-        return logits
+        return logits, speculative_logits
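Here the Mamba head previously loaded the `{prefix}.embedding` weights through `FastLinear`; it now goes through `SpeculativeHead` over the same tensors, so the LM head appears to stay tied to the input embeddings. A tiny sketch of what that weight tying means in plain PyTorch:

```python
import torch
import torch.nn.functional as F

vocab, hidden = 1000, 64
embed = torch.nn.Embedding(vocab, hidden)

tokens = torch.randint(0, vocab, (2, 5))
hidden_states = embed(tokens)                   # [2, 5, hidden]

# Tied LM head: reuse the embedding matrix as the output projection.
logits = F.linear(hidden_states, embed.weight)  # [2, 5, vocab]
assert logits.shape == (2, 5, vocab)
```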
@@ -21,7 +21,7 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )

@@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if not config.tie_word_embeddings:
             raise ValueError("MPTForCausalLM only supports tied word embeddings")
         self.transformer = MPTModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
         self.logit_scale = None
@@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
         )
-        logits = self.lm_head(outputs.last_hidden_state)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(
@@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
             )
-        return CausalLMOutputWithPast(
+        return (
+            CausalLMOutputWithPast(
                 loss=loss,
                 logits=logits,
                 past_key_values=outputs.past_key_values,
                 hidden_states=outputs.hidden_states,
                 attentions=outputs.attentions,
+            ),
+            speculative_logits,
         )

     def prepare_inputs_for_generation(
@@ -44,7 +44,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )


@@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.gpt_neox = GPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )

 EPS = 1e-5
@@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel):

         self.model = OPTModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="model.decoder.embed_tokens", weights=weights
         )

@@ -13,7 +13,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLinear,
 )

@@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
-        self.linear = TensorParallelHead.load(
+        self.linear = SpeculativeHead.load(
             config=config, prefix="lm_head.linear", weights=weights
         )

@@ -42,7 +42,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )


@@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         )

         try:
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="lm_head", weights=weights
             )
         except RuntimeError:
             # Some models like t5-small were saved with shared weights unlike flan
             # Since they are declared as the same arch we have no choice but hope
             # that this is OK instead of using a proper flag.
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="shared", weights=weights
             )

@@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
         sequence_output = sequence_output * (self.model_dim**-0.5)

-        lm_logits = self.lm_head(sequence_output)
+        logits, speculative_logits = self.lm_head(sequence_output)

         loss = None
         if labels is not None:
@@ -1140,9 +1140,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output

-        return Seq2SeqLMOutput(
+        return (
+            Seq2SeqLMOutput(
                 loss=loss,
-                logits=lm_logits,
+                logits=logits,
                 past_key_values=decoder_outputs.past_key_values,
                 decoder_hidden_states=decoder_outputs.hidden_states,
                 decoder_attentions=decoder_outputs.attentions,
@@ -1150,6 +1151,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
                 encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                 encoder_hidden_states=encoder_outputs.hidden_states,
                 encoder_attentions=encoder_outputs.attentions,
+            ),
+            speculative_logits,
         )

     def prepare_inputs_for_generation(
@@ -723,7 +723,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            self.cuda_graphs[bs]["logits"] = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@@ -734,6 +734,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 lm_head_indices=None,
             )
+            self.cuda_graphs[bs]["logits"] = logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

     def warmup(self, batch: FlashCausalLMBatch):
@@ -805,7 +807,9 @@ class FlashCausalLM(Model):

         return int(num_blocks * BLOCK_SIZE)

-    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
+    def forward(
+        self, batch: FlashCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -900,9 +904,14 @@ class FlashCausalLM(Model):

         # Replay the graph
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -926,16 +935,11 @@ class FlashCausalLM(Model):
             batch.slots = slots

         try:
-            out = self.forward(batch)
+            out, speculative_logits = self.forward(batch)
         except Exception as e:
             del batch
             raise e

-        if isinstance(out, tuple):
-            out, speculative_logits = out
-        else:
-            speculative_logits = None
-
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
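The CUDA-graph path now keeps two static output tensors, `logits` and `speculative_logits`, and slices both on replay. A hedged, heavily simplified sketch of that capture/replay pattern (toy model, requires a CUDA device; the real code also manages static KV-cache inputs):

```python
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 32).cuda().eval()
    static_input = torch.zeros(8, 16, device="cuda")

    with torch.inference_mode():
        # Warm-up on a side stream before capture, as recommended for CUDA graphs.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            model(static_input)
        torch.cuda.current_stream().wait_stream(s)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            logits = model(static_input)              # static output 1
            speculative_logits = logits.unsqueeze(1)  # stand-in for Medusa heads

        cache = {"graph": graph, "logits": logits, "speculative_logits": speculative_logits}

        # Replay for a smaller batch: copy inputs in, replay, slice both outputs.
        bs = 4
        static_input[:bs].copy_(torch.randn(bs, 16, device="cuda"))
        cache["graph"].replay()
        out = cache["logits"][:bs]
        spec = (
            cache["speculative_logits"][:bs]
            if cache["speculative_logits"] is not None
            else None
        )
```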
@@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)

@@ -59,36 +60,6 @@ class FlashGemma(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         model = FlashGemmaForCausalLM(config, weights)
-        if use_medusa:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)

         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
@@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)

@@ -67,37 +68,6 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         model = FlashLlamaForCausalLM(config, weights)
-        if use_medusa:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)

         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
@@ -294,6 +294,7 @@ class BaseFlashMistral(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -319,6 +320,7 @@ class BaseFlashMistral(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         # Set context windows
         if config.sliding_window is not None:
@@ -394,7 +396,7 @@ class BaseFlashMistral(FlashCausalLM):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            self.cuda_graphs[bs]["logits"] = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@@ -406,9 +408,13 @@ class BaseFlashMistral(FlashCausalLM):
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
+            self.cuda_graphs[bs]["logits"] = logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, batch: FlashMistralBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -479,7 +485,7 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph = self.cuda_graphs.get(padded_bs, None)

         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
@@ -493,7 +499,7 @@ class BaseFlashMistral(FlashCausalLM):
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
-            return logits
+            return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
@@ -511,7 +517,13 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits


 class FlashMistral(BaseFlashMistral):
@@ -520,6 +532,7 @@ class FlashMistral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -529,6 +542,7 @@ class FlashMistral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)

@@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM):
         )

         config.quantize = quantize
+        config.use_medusa = use_medusa
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)

@@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             trust_remote_code=True,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.transpose = config.architectures[0].startswith("GPT2")

         torch.distributed.barrier(group=self.process_group)
@@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.vision_config.quantize = quantize

         tokenizer = LlamaTokenizerFast.from_pretrained(
@@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids

-        outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
+        outputs, speculative_logits = self.model.forward(**kwargs)
+        return (
+            outputs.logits,
+            speculative_logits,
+            outputs.past_key_values,
+            outputs.image_hidden_states,
+        )

     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
             :, : -batch.padding_right_offset
         ]

-        logits, past, image_hidden_states = self.forward(
+        logits, speculative_logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
             attention_mask=attention_mask,
             position_ids=batch.position_ids,
@@ -408,6 +408,7 @@ class Mamba(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -444,6 +445,7 @@ class Mamba(Model):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -505,7 +507,7 @@ class Mamba(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            logits = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids, inference_params=inference_params
             )
         torch.cuda.synchronize()
@@ -514,6 +516,7 @@ class Mamba(Model):
             "inference_params": inference_params,
             "graph": graph,
             "logits": logits,
+            "speculative_logits": speculative_logits,
         }
         self.cuda_graphs[batch_size] = graph_dict

@@ -556,9 +559,14 @@ class Mamba(Model):
         inference_params.ssm_states.copy_(
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
@@ -589,7 +597,9 @@ class Mamba(Model):
             batch.inference_params = inference_params

         # Forward pass
-        logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(
+            input_ids, inference_params=batch.inference_params
+        )

         # batch.inference_params = new_inference_params
         # Results
@@ -43,6 +43,7 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -75,6 +76,7 @@ class MPTSharded(CausalLM):
             config = json.load(f)
         config = PretrainedConfig(**config)
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)

@@ -22,6 +22,7 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -47,6 +48,7 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         tokenizer.pad_token_id = config.pad_token_id

         torch.distributed.barrier(group=self.process_group)
@@ -22,6 +22,7 @@ class Phi(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,6 +53,7 @@ class Phi(CausalLM):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -532,6 +532,7 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -596,6 +597,7 @@ class Seq2SeqLM(Model):
         past_key_values: Optional = None,
     ) -> Tuple[
         torch.Tensor,
+        Optional[torch.Tensor],
         torch.Tensor,
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
@@ -609,8 +611,15 @@ class Seq2SeqLM(Model):
             past_key_values=past_key_values,
             use_cache=True,
         )
+        if isinstance(outputs, tuple):
+            # Our custom models
+            outputs, speculative_logits = outputs
+        else:
+            # Generic transformers models
+            speculative_logits = None
         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )
@@ -635,7 +644,7 @@ class Seq2SeqLM(Model):
         else:
             encoder_last_hidden_state = None

-        logits, encoder_last_hidden_state, past = self.forward(
+        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
             batch.decoder_input_ids,
@@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM):

         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )
@@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info(
         and "arguments" not in s.rfilename
         and "args" not in s.rfilename
         and "training" not in s.rfilename
+        and "medusa_lm_head" not in s.rfilename
     ]


@@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
         and "args" not in f
         and "adapter" not in f
         and "training" not in f
+        and "medusa_lm_head" not in f
    ]
    return filenames

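A small sketch of the effect of the new filter: `medusa_lm_head` files are no longer treated as base-model weights when listing `.safetensors` files in a directory (the directory path in the usage comment is a placeholder):

```python
from pathlib import Path
from typing import List


def list_weight_files(d: Path, extension: str = ".safetensors") -> List[str]:
    return [
        str(f)
        for f in d.glob(f"*{extension}")
        if "arguments" not in f.name
        and "args" not in f.name
        and "adapter" not in f.name
        and "training" not in f.name
        and "medusa_lm_head" not in f.name  # new: keep the Medusa head separate
    ]


# list_weight_files(Path("/data/models--bigscience--bloom/snapshots/..."))
```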
@@ -4,7 +4,7 @@ import torch.distributed

 from torch import nn
 from torch.nn import functional as F
-from typing import List
+from typing import List, Tuple, Optional
 from loguru import logger
 from functools import lru_cache

@@ -380,6 +380,96 @@ class SuperLayer(nn.Module):
         return self.linear.forward(x)


+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(
+            config, prefix=f"{prefix}.linear", weights=weights, bias=True
+        )
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [
+                MedusaHead(config, prefix=f"{i}", weights=weights)
+                for i in range(config["medusa_num_heads"])
+            ]
+        )
+
+    def forward(self, x):
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList(
+            [
+                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
+                for i in range(config["medusa_num_layers"])
+            ]
+        )
+        n = len(self.blocks)
+        self.out = FastLinear.load(
+            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
+        )
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
+
+
+class SpeculativeHead(nn.Module):
+    def __init__(self, lm_head, medusa):
+        super().__init__()
+        self.lm_head = lm_head
+        self.medusa = medusa
+
+    @staticmethod
+    def load(config, prefix: str, weights):
+        lm_head = TensorParallelHead.load(config, prefix, weights)
+        use_medusa = config.use_medusa
+        if use_medusa:
+            from pathlib import Path
+            from safetensors import safe_open
+            import json
+
+            medusa_config = str(Path(use_medusa) / "config.json")
+            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    weights.routing[k] = filename
+
+            medusa = MedusaModel(config, weights)
+        else:
+            medusa = None
+        return SpeculativeHead(lm_head, medusa)
+
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        speculative_logits = self.medusa(input) if self.medusa is not None else None
+        return logits, speculative_logits
+
+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
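A self-contained sketch of the new head layout: each Medusa head is a stack of residual SiLU blocks followed by a vocabulary projection, the heads' logits are stacked on a new dimension, and `SpeculativeHead` returns `(lm_logits, speculative_logits)` with `speculative_logits=None` when no Medusa weights are configured. This uses plain `nn.Linear` instead of the repo's `FastLinear`/`TensorParallelHead` loaders, so it is an illustration rather than the actual implementation:

```python
from typing import Optional, Tuple

import torch
from torch import nn


class ResBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))


class MedusaHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, num_layers: int):
        super().__init__()
        self.blocks = nn.ModuleList(ResBlock(hidden_size) for _ in range(num_layers))
        self.out = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return self.out(x)


class MedusaModel(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, num_heads: int, num_layers: int):
        super().__init__()
        self.heads = nn.ModuleList(
            MedusaHead(hidden_size, vocab_size, num_layers) for _ in range(num_heads)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stack per-head logits on a new dimension: [..., num_heads, vocab].
        return torch.stack([head(x) for head in self.heads], dim=1)


class SpeculativeHead(nn.Module):
    def __init__(self, lm_head: nn.Module, medusa: Optional[nn.Module]):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(x)
        speculative_logits = self.medusa(x) if self.medusa is not None else None
        return logits, speculative_logits


hidden, vocab = 64, 1000
head = SpeculativeHead(
    nn.Linear(hidden, vocab, bias=False),
    MedusaModel(hidden, vocab, num_heads=4, num_layers=1),
)
logits, spec = head(torch.randn(3, hidden))
assert logits.shape == (3, vocab) and spec.shape == (3, 4, vocab)
```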
@@ -1,59 +0,0 @@
-import torch
-from dataclasses import dataclass
-from text_generation_server.utils.layers import TensorParallelHead, FastLinear
-
-
-@dataclass
-class Output:
-    logits: torch.FloatTensor = None
-    speculative_logits: torch.FloatTensor = None
-
-
-class ResBlock(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.linear = FastLinear.load(
-            config, prefix=f"{prefix}.linear", weights=weights, bias=True
-        )
-        self.act = torch.nn.SiLU()
-
-    def forward(self, x):
-        return x + self.act(self.linear(x))
-
-
-class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights, lm_head):
-        super().__init__()
-        self.heads = torch.nn.ModuleList(
-            [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
-            ]
-        )
-        self.lm_head = lm_head
-
-    def forward(self, x):
-        logits = self.lm_head(x)
-        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
-        return logits, speculative_logits
-
-
-class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
-        super().__init__()
-        self.blocks = torch.nn.ModuleList(
-            [
-                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
-            ]
-        )
-        n = len(self.blocks)
-        self.out = FastLinear.load(
-            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
-        )
-
-    def forward(self, x):
-        for block in self.blocks:
-            x = block(x)
-        x = self.out(x)
-        return x