MLPSpeculator. (#1865)

# What does this PR do?

Adds support for MLPSpeculator speculative decoding (`MLPSpeculatorPreTrainedModel` checkpoints). The Medusa-specific `use_medusa` option is generalized into a `speculator` mechanism: `get_model` now builds a `{"path": ..., "model_paths": [...]}` descriptor for both Medusa and `mlp_speculator` checkpoints, `SpeculativeHead.load` dispatches on the architecture declared in the speculator's `config.json` (falling back to the Medusa heads), and the new `MLPSpeculatorHead` / `MLPSpeculatorModel` in `layers/mlp.py` implement the speculator forward pass. All model constructors now take `speculator` instead of `use_medusa`.
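After this change, `config.speculator` is a small dict describing where the speculator weights live. A sketch of the two shapes `get_model` produces (the paths and shard names here are illustrative):

```python
from pathlib import Path

# Medusa checkpoints keep a single weights file:
speculator = {
    "path": Path("/data/medusa-head"),  # illustrative local path
    "model_paths": ["medusa_lm_head.safetensors"],
}

# mlp_speculator checkpoints may ship several safetensors files; get_model
# collects every top-level *.safetensors name (skipping training artifacts):
speculator = {
    "path": Path("/data/mlp-speculator"),  # illustrative local path
    "model_paths": [
        "model-00001-of-00002.safetensors",  # illustrative shard names
        "model-00002-of-00002.safetensors",
    ],
}
```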


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Joshua Rosenkranz <joshua.rosenkranz@gmail.com>
Nicolas Patry, 2024-05-14 12:33:18 +02:00, committed by GitHub
commit e3d765645a (parent 3136f27f36)
35 changed files with 399 additions and 139 deletions

View File

@@ -3,11 +3,11 @@ from text_generation_server.layers.tensor_parallel import (
     TensorParallelRowLinear,
     TensorParallelEmbedding,
 )
-from text_generation_server.layers.speculative import SpeculativeHead
 from text_generation_server.layers.linear import (
     get_linear,
     FastLinear,
 )
+from text_generation_server.layers.speculative import SpeculativeHead
 
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm

View File

@@ -69,21 +69,24 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        path = speculator["path"]
+        medusa_config = str(Path(path) / "config.json")
 
-        with open(medusa_config, "r") as f:
-            medusa_config = json.load(f)
-        routing = weights.routing
-        with safe_open(filename, framework="pytorch") as f:
-            for k in f.keys():
-                if k in routing and routing[k] != filename:
-                    raise RuntimeError(
-                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
-                    )
-                routing[k] = filename
+        for fname in speculator["model_paths"]:
+            filename = str(Path(path) / fname)
+
+            with open(medusa_config, "r") as f:
+                medusa_config = json.load(f)
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing and routing[k] != filename:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    routing[k] = filename
 
         medusa = MedusaModel(config, medusa_config, weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)

@@ -108,10 +111,10 @@ class MedusaHeadV2(nn.Module):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)

View File

@ -0,0 +1,176 @@
import torch
import math
from torch import nn
from torch.nn import functional as F
from typing import Optional, Tuple
from text_generation_server.layers import TensorParallelEmbedding, FastLinear
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.utils.speculate import get_speculate
class MLPSpeculatorLayerNorm(nn.Module):
"""
A L2 normalization implementation
...
Args
----
normalized_shape : int
Dimensionality of input data (size of final tensor axis)
elementwise_scale_weight : torch.Tensor
learned scaling term after normalization?
elementwise_shift_bias : torch.Tensor
learned bias term after normalization?
eps : float
Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
"""
def __init__(
self,
prefix,
config,
weights,
eps=1e-06,
):
super(MLPSpeculatorLayerNorm, self).__init__()
self.weight = weights.get_tensor(f"{prefix}.weight")
self.bias = weights.get_tensor(f"{prefix}.bias")
self.eps = eps
def forward(self, x):
xf = x
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
x = xf.type_as(x)
x = self.weight * x
x = x + self.bias
return x
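For intuition, the forward above is RMS normalization over the last axis followed by a learned elementwise scale and shift; unlike `nn.LayerNorm`, no mean is subtracted. A standalone restatement with illustrative stand-in tensors (not loaded TGI weights):

```python
import torch

# Stand-alone restatement of MLPSpeculatorLayerNorm.forward; `weight`,
# `bias`, and the shapes are illustrative stand-ins.
x = torch.randn(4, 8)
weight, bias, eps = torch.randn(8), torch.randn(8), 1e-6

# RMS-normalize over the last axis, then scale and shift elementwise.
normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
out = weight * normed + bias
```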
class MLPSpeculatorModel(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.config = config
self.n_predict = get_speculate()
self.hidden_size = config.hidden_size
self.emb = nn.ModuleList(
[
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
for i in range(self.n_predict)
]
)
self.proj = [
FastLinear.load(
config,
prefix=f"{prefix}.proj.{i}",
weights=weights,
bias=False,
)
for i in range(self.n_predict)
]
self.head = nn.ModuleList(
[
FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
for i in range(self.n_predict)
]
)
self.ln = nn.ModuleList(
[
MLPSpeculatorLayerNorm(
prefix=f"{prefix}.ln.{i}",
config=config,
weights=weights,
)
for i in range(self.n_predict)
]
)
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
self.state_weight = 0.5 ** (0.5 / self.n_predict)
self.emb_weight = math.sqrt(1 - self.state_weight**2)
self.activation = nn.GELU()
# TODO
self.vsize = config.vocab_size
self.inner_dim = config.speculator_config["inner_dim"]
self.top_k_tokens_per_head = [1] * self.n_predict
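The weighting comment above can be checked directly: each head scales the carried state by `state_weight` and mixes in a token embedding weighted so that per-step variance is preserved (`state_weight**2 + emb_weight**2 == 1`), so after `n_predict` steps the initial state retains `state_weight ** (2 * n_predict) == 0.5` of the variance, i.e. 50% in expectation. A quick check:

```python
import math

# Verify the 50%-of-state-magnitude claim for a few head counts.
for n_predict in (1, 2, 4, 8):
    state_weight = 0.5 ** (0.5 / n_predict)
    emb_weight = math.sqrt(1 - state_weight**2)
    # Each step preserves total variance: w_state^2 + w_emb^2 == 1.
    assert abs(state_weight**2 + emb_weight**2 - 1) < 1e-12
    # After n_predict steps, state_0 keeps exactly half the variance.
    assert abs(state_weight ** (2 * n_predict) - 0.5) < 1e-12
```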
def forward(
self,
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
):
top_k_tokens_per_head = self.top_k_tokens_per_head
# k indicates # of candidates
# h indicates # of generated tokens
state = hidden_states
b = state.size(0)
ind = input_ids.unsqueeze(0)
all_probs = torch.empty(
b, self.n_predict, self.vsize, device=state.device
) # b k h v
assert (
len(top_k_tokens_per_head) == self.n_predict
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
for i in range(self.n_predict):
# Project and predict
z = self.emb[i](ind)
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
state = self.proj[i](state) * self.state_weight + z
state = self.activation(self.ln[i](state)) # b k d
probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
# Update candidate set with new predictions
# Update distribution set with new logits
all_probs[:, i] = probs.exp()
# Update state, log_probs and ind for new predictions
state = state.unsqueeze(2).expand(
-1, -1, top_k_tokens_per_head[i], -1
) # b k k' d
state = state.reshape(-1, b, state.size(3)) # b kk' d
ind = preds.view(-1, b) # b kk'
speculative_logits = all_probs
return speculative_logits
class MLPSpeculatorHead(nn.Module):
def __init__(self, lm_head, mlp_speculator):
super().__init__()
self.lm_head = lm_head
self.mlp_speculator = mlp_speculator
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
# If we have too many tokens, we skip speculative logits
if input.shape[0] > 128:
return logits, None
input_ids = logits.argmax(dim=-1)
speculative_logits = self.mlp_speculator(input, input_ids)
return logits, speculative_logits
@staticmethod
def load(config, prefix: str, weights):
from pathlib import Path
from safetensors import safe_open
speculator_path = config.speculator["path"]
for fname in config.speculator["model_paths"]:
filename = str(Path(speculator_path) / fname)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MLPSpeculatorHead(lm_head, mlp_speculator)
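End to end, the speculator consumes the base model's final hidden state plus its greedy token and emits `n_predict` draft distributions. A toy re-run of the recurrence with random weights, omitting the layer norm and the top-k > 1 candidate expansion (every name below is an illustrative stand-in, not the TGI classes):

```python
import math
import torch
import torch.nn.functional as F

b, d, v, n_predict = 2, 16, 32, 3
state = torch.randn(b, d)      # base model's last hidden state
ind = torch.randint(v, (b,))   # greedy token from the base LM head

state_weight = 0.5 ** (0.5 / n_predict)
emb_weight = math.sqrt(1 - state_weight**2)
emb = [torch.nn.Embedding(v, d) for _ in range(n_predict)]
proj = [torch.nn.Linear(d, d, bias=False) for _ in range(n_predict)]
head = [torch.nn.Linear(d, v, bias=False) for _ in range(n_predict)]

all_probs = torch.empty(b, n_predict, v)
for i in range(n_predict):
    z = emb[i](ind) * (emb_weight * math.sqrt(d / 2))  # scaled draft-token embedding
    state = F.gelu(proj[i](state) * state_weight + z)  # blend state with new token
    logp = F.log_softmax(head[i](state), dim=-1)
    all_probs[:, i] = logp.exp()                       # draft distribution for head i
    ind = logp.argmax(-1)                              # feed the greedy pick forward

print(all_probs.shape)  # torch.Size([2, 3, 32])
```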

View File

@@ -1,34 +1,51 @@
 import torch
+import json
 
 from typing import Tuple, Optional
 
-from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
 from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
+from text_generation_server.layers.mlp import MLPSpeculatorHead
 
 
 class SpeculativeHead(torch.nn.Module):
-    def __init__(self, lm_head, medusa):
+    def __init__(self, lm_head, speculator):
         super().__init__()
         self.head = lm_head
-        self.medusa = medusa
+        self.speculator = speculator
 
     @staticmethod
     def load(config, prefix: str, weights):
-        use_medusa = config.use_medusa
-        if use_medusa:
-            lm_head = None
+        speculator = config.speculator
+        if speculator:
+            speculator_path = config.speculator["path"]
+            speculator_config = str(speculator_path / "config.json")
+            with open(speculator_config, "r") as f:
+                speculator_config = json.load(f)
+            config.speculator_config = speculator_config
             try:
-                medusa = MedusaHeadV1.load(config, prefix, weights)
-            except:
-                medusa = MedusaHeadV2(config, prefix, weights)
+                architecture = speculator_config["architectures"][0]
+
+                if architecture == "MLPSpeculatorPreTrainedModel":
+                    speculator = MLPSpeculatorHead.load(config, prefix, weights)
+                else:
+                    speculator = None
+            except KeyError:
+                try:
+                    speculator = MedusaHeadV1.load(config, prefix, weights)
+                except:
+                    speculator = MedusaHeadV2(config, prefix, weights)
+            lm_head = None
         else:
             lm_head = TensorParallelHead.load(config, prefix, weights)
-            medusa = None
-        return SpeculativeHead(lm_head, medusa)
+            speculator = None
+        return SpeculativeHead(lm_head, speculator)
 
     def forward(
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.medusa is not None:
-            return self.medusa(input)
+        if self.speculator is not None:
+            return self.speculator(input)
         assert self.head is not None
         logits = self.head(input)

View File

@@ -1,9 +1,10 @@
 import torch
+import os
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
@@ -135,8 +136,9 @@ def get_model(
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
+    model_type = config_dict.get("model_type", None)
 
-    use_medusa = None
+    speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
         medusa_revision = revision
@@ -156,6 +158,8 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        # Reload model type from parent.
+        model_type = config_dict.get("model_type", None)
         is_local = Path(medusa_model_id).exists()
         if not is_local:
             medusa_config = hf_hub_download(
@@ -166,11 +170,70 @@ def get_model(
                 revision=medusa_revision,
                 filename="medusa_lm_head.safetensors",
             )
-            use_medusa = Path(medusa_config).parent
+            speculator = {
+                "path": Path(medusa_config).parent,
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
         else:
-            use_medusa = Path(medusa_model_id)
+            speculator = {
+                "path": Path(medusa_model_id),
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
 
         method = "medusa"
+    elif model_type == "mlp_speculator":
+        mlp_model_id = model_id
+        mlp_revision = revision
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        speculate_mlp = config_dict["n_predict"]
+        if speculate is not None:
+            if speculate > speculate_mlp:
+                raise RuntimeError(
+                    f"Speculate is set to `{speculate}` but this mlp_speculator model only has `{speculate_mlp}` heads, please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_mlp)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        # Reload model type from parent.
+        model_type = config_dict.get("model_type", None)
+        is_local = Path(mlp_model_id).exists()
+        extension = ".safetensors"
+        if not is_local:
+            mlp_speculator_config = hf_hub_download(
+                mlp_model_id, revision=mlp_revision, filename="config.json"
+            )
+            api = HfApi()
+            info = api.model_info(mlp_model_id, revision=mlp_revision)
+            filenames = [
+                s.rfilename
+                for s in info.siblings
+                if s.rfilename.endswith(extension)
+                and len(s.rfilename.split("/")) == 1
+                and "arguments" not in s.rfilename
+                and "args" not in s.rfilename
+                and "training" not in s.rfilename
+            ]
+            for filename in filenames:
+                hf_hub_download(
+                    mlp_model_id,
+                    revision=mlp_revision,
+                    filename=filename,
+                )
+            speculator = {
+                "path": Path(mlp_speculator_config).parent,
+                "model_paths": filenames,
+            }
+        else:
+            speculator = Path(mlp_model_id)
+            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
+            speculator = {"path": speculator, "model_paths": filenames}
+
+        method = "mlp_speculator"
     else:
         method = "n-gram"
@@ -178,7 +241,6 @@ def get_model(
     if speculate > 0:
         logger.info(f"Using speculation {method} with {speculate} input ids.")
 
-    model_type = config_dict.get("model_type", None)
     if model_type is None:
         # TODO: fix how we determine model type for Mamba
         if "ssm_cfg" in config_dict:
@ -202,7 +264,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -212,7 +274,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -227,7 +289,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -240,7 +302,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -250,7 +312,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -259,7 +321,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -270,7 +332,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -279,7 +341,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -288,7 +350,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -299,7 +361,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -308,7 +370,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -323,7 +385,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -334,7 +396,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -345,7 +407,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -355,7 +417,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -366,7 +428,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -377,7 +439,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -388,7 +450,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -399,7 +461,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -410,7 +472,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -424,7 +486,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -435,7 +497,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -444,7 +506,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -458,7 +520,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -469,7 +531,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -483,7 +545,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -494,7 +556,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -520,7 +582,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -544,7 +606,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -554,7 +616,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -564,7 +626,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -574,7 +636,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -586,7 +648,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -599,7 +661,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -623,7 +685,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -632,7 +694,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -644,7 +706,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -653,7 +715,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -42,7 +42,7 @@ class BLOOMSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -71,7 +71,7 @@ class BLOOMSharded(CausalLM):
)
config.pad_token_id = 3
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@ -482,12 +482,12 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if use_medusa:
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device("cuda")

View File

@ -683,9 +683,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
config.vision_config.use_medusa = config.use_medusa
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.use_medusa = config.use_medusa
config.text_config.speculator = config.speculator
vision_config = config.vision_config
self.text_model = load_text_model(

View File

@ -135,7 +135,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.use_medusa = config.use_medusa
config.text_config.speculator = config.speculator
self.language_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,

View File

@ -1101,6 +1101,8 @@ class FlashCausalLM(Model):
next_token_texts = []
left = 0
logger.info(f"Accepted ids {n_accepted_ids}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token

View File

@ -24,7 +24,7 @@ class FlashCohere(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -49,7 +49,7 @@ class FlashCohere(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)

View File

@ -26,7 +26,7 @@ class FlashDbrx(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -74,7 +74,7 @@ class FlashDbrx(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)

View File

@ -25,7 +25,7 @@ class FlashGemma(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -50,7 +50,7 @@ class FlashGemma(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)

View File

@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)

View File

@ -313,7 +313,7 @@ class BaseFlashMistral(FlashCausalLM):
config_cls=AutoConfig,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
@ -340,7 +340,7 @@ class BaseFlashMistral(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
# Set context windows
if getattr(config, "sliding_window", None) is not None:
@ -567,7 +567,7 @@ class FlashMistral(BaseFlashMistral):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -577,7 +577,7 @@ class FlashMistral(BaseFlashMistral):
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -15,7 +15,7 @@ class FlashMixtral(BaseFlashMistral):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -25,7 +25,7 @@ class FlashMixtral(BaseFlashMistral):
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -25,7 +25,7 @@ class FlashNeoXSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -51,7 +51,7 @@ class FlashNeoXSharded(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@ -25,7 +25,7 @@ class FlashPhi(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -48,7 +48,7 @@ class FlashPhi(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
@ -58,7 +58,7 @@ class FlashPhi(FlashCausalLM):
weights._set_gptq_params(model_id, revision)
model = FlashPhiForCausalLM(config, weights)
if use_medusa:
if speculator:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
@ -66,19 +66,19 @@ class FlashPhi(FlashCausalLM):
from pathlib import Path
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
Path(speculator).exists() and Path(speculator).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
speculator, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
speculator, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
medusa_config = str(Path(speculator) / "config.json")
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)

View File

@ -30,7 +30,7 @@ class FlashQwen2(BaseFlashMistral):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -53,7 +53,7 @@ class FlashQwen2(BaseFlashMistral):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
# Set context windows
if config.sliding_window is not None:

View File

@ -26,7 +26,7 @@ class FlashRWSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -66,7 +66,7 @@ class FlashRWSharded(FlashCausalLM):
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
if config.quantize == "gptq":
weights._set_gptq_params(model_id, revision)

View File

@ -29,7 +29,7 @@ class FlashSantacoderSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -57,7 +57,7 @@ class FlashSantacoderSharded(FlashCausalLM):
trust_remote_code=True,
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group)

View File

@ -29,7 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -52,7 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
# Set context windows
if config.sliding_window is not None:

View File

@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -195,7 +195,7 @@ class GalacticaSharded(CausalLM):
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@ -24,7 +24,7 @@ class GPTNeoxSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -51,7 +51,7 @@ class GPTNeoxSharded(CausalLM):
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@ -31,7 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -52,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
config.vision_config.quantize = quantize
tokenizer = LlamaTokenizerFast.from_pretrained(

View File

@ -18,7 +18,7 @@ class Idefics2(VlmCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -35,7 +35,7 @@ class Idefics2(VlmCausalLM):
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -18,7 +18,7 @@ class LlavaNext(VlmCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -30,7 +30,7 @@ class LlavaNext(VlmCausalLM):
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -408,7 +408,7 @@ class Mamba(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -445,7 +445,7 @@ class Mamba(Model):
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)

View File

@ -43,7 +43,7 @@ class MPTSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -76,7 +76,7 @@ class MPTSharded(CausalLM):
config = json.load(f)
config = PretrainedConfig(**config)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)

View File

@ -22,7 +22,7 @@ class OPTSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -48,7 +48,7 @@ class OPTSharded(CausalLM):
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)

View File

@ -22,7 +22,7 @@ class Phi(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -53,7 +53,7 @@ class Phi(CausalLM):
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)

View File

@ -12,11 +12,11 @@ class RW(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if use_medusa:
if speculator:
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
if torch.cuda.is_available():

View File

@ -19,7 +19,7 @@ class SantaCoder(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):

View File

@ -532,12 +532,12 @@ class Seq2SeqLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if use_medusa:
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device("cuda")

View File

@ -25,7 +25,7 @@ class T5Sharded(Seq2SeqLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@ -43,7 +43,7 @@ class T5Sharded(Seq2SeqLM):
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
config.use_medusa = use_medusa
config.speculator = speculator
tokenizer = AutoTokenizer.from_pretrained(
model_id,