MLPSpeculator. (#1865)

# What does this PR do?
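Adds support for MLPSpeculator speculative-decoding heads (model repos whose `config.json` declares `"architectures": ["MLPSpeculatorPreTrainedModel"]` and a `model_type` of `mlp_speculator`), alongside the existing Medusa support. To do so, the internal `use_medusa` plumbing is generalized into a `speculator` value that every model constructor now accepts and stores on its config.

A minimal sketch of the new shape, with hypothetical paths (the real value is built in `get_model` further down):

```python
from pathlib import Path

# Illustrative only: `config.speculator` replaces `config.use_medusa` and carries
# both the speculator directory and the safetensors files to route weights from.
speculator = {
    "path": Path("/data/my-mlp-speculator"),
    "model_paths": ["model.safetensors"],
}
```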



Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.


---------

Co-authored-by: Joshua Rosenkranz <joshua.rosenkranz@gmail.com>
Nicolas Patry, 2024-05-14 12:33:18 +02:00, committed by GitHub
commit e3d765645a (parent 3136f27f36)
35 changed files with 399 additions and 139 deletions

View File

@@ -3,11 +3,11 @@ from text_generation_server.layers.tensor_parallel import (
     TensorParallelRowLinear,
     TensorParallelEmbedding,
 )
-from text_generation_server.layers.speculative import SpeculativeHead
 from text_generation_server.layers.linear import (
     get_linear,
     FastLinear,
 )
+from text_generation_server.layers.speculative import SpeculativeHead

 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm

View File

@@ -69,10 +69,13 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json

-        use_medusa = config.use_medusa
+        speculator = config.speculator

-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        path = speculator["path"]
+        medusa_config = str(Path(path) / "config.json")
+        for fname in speculator["model_paths"]:
+            filename = str(Path(path) / fname)

         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -108,10 +111,10 @@ class MedusaHeadV2(nn.Module):
         from safetensors import safe_open
         import json

-        use_medusa = config.use_medusa
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        speculator = config.speculator
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")

         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)

View File

@@ -0,0 +1,176 @@
import torch
import math
from torch import nn
from torch.nn import functional as F
from typing import Optional, Tuple
from text_generation_server.layers import TensorParallelEmbedding, FastLinear
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.utils.speculate import get_speculate


class MLPSpeculatorLayerNorm(nn.Module):
    """
    A L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    elementwise_scale_weight : torch.Tensor
        learned scaling term after normalization?
    elementwise_shift_bias : torch.Tensor
        learned bias term after normalization?
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        prefix,
        config,
        weights,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.weight = weights.get_tensor(f"{prefix}.weight")
        self.bias = weights.get_tensor(f"{prefix}.bias")
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x


class MLPSpeculatorModel(torch.nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.config = config
        self.n_predict = get_speculate()
        self.hidden_size = config.hidden_size

        self.emb = nn.ModuleList(
            [
                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
                for i in range(self.n_predict)
            ]
        )
        self.proj = [
            FastLinear.load(
                config,
                prefix=f"{prefix}.proj.{i}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_predict)
        ]
        self.head = nn.ModuleList(
            [
                FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
                for i in range(self.n_predict)
            ]
        )
        self.ln = nn.ModuleList(
            [
                MLPSpeculatorLayerNorm(
                    prefix=f"{prefix}.ln.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(self.n_predict)
            ]
        )

        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
        self.state_weight = 0.5 ** (0.5 / self.n_predict)
        self.emb_weight = math.sqrt(1 - self.state_weight**2)
        self.activation = nn.GELU()
        # TODO
        self.vsize = config.vocab_size
        self.inner_dim = config.speculator_config["inner_dim"]
        self.top_k_tokens_per_head = [1] * self.n_predict

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        top_k_tokens_per_head = self.top_k_tokens_per_head

        # k indicates # of candidates
        # h indicates # of generated tokens
        state = hidden_states
        b = state.size(0)
        ind = input_ids.unsqueeze(0)
        all_probs = torch.empty(
            b, self.n_predict, self.vsize, device=state.device
        )  # b k h v
        assert (
            len(top_k_tokens_per_head) == self.n_predict
        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"

        for i in range(self.n_predict):
            # Project and predict
            z = self.emb[i](ind)
            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
            state = self.proj[i](state) * self.state_weight + z
            state = self.activation(self.ln[i](state))  # b k d
            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'

            # Update candidate set with new predictions
            # Update distribution set with new logits
            all_probs[:, i] = probs.exp()

            # Update state, log_probs and ind for new predictions
            state = state.unsqueeze(2).expand(
                -1, -1, top_k_tokens_per_head[i], -1
            )  # b k k' d
            state = state.reshape(-1, b, state.size(3))  # b kk' d
            ind = preds.view(-1, b)  # b kk'

        speculative_logits = all_probs
        return speculative_logits


class MLPSpeculatorHead(nn.Module):
    def __init__(self, lm_head, mlp_speculator):
        super().__init__()
        self.lm_head = lm_head
        self.mlp_speculator = mlp_speculator

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > 128:
            return logits, None

        input_ids = logits.argmax(dim=-1)
        speculative_logits = self.mlp_speculator(input, input_ids)
        return logits, speculative_logits

    @staticmethod
    def load(config, prefix: str, weights):
        from pathlib import Path
        from safetensors import safe_open

        speculator_path = config.speculator["path"]

        for fname in config.speculator["model_paths"]:
            filename = str(Path(speculator_path) / fname)
            routing = weights.routing
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing and routing[k] != filename:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename

        mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MLPSpeculatorHead(lm_head, mlp_speculator)
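
The file above wires the speculator into the server's tensor-parallel loading machinery. As a self-contained illustration of the recurrence it implements, here is a toy sketch with random weights and made-up dimensions (it substitutes a plain `nn.LayerNorm` for the custom L2 norm above, purely for brevity):

```python
import math
import torch
from torch import nn
from torch.nn import functional as F

vocab, hidden, inner_dim, n_predict = 128, 16, 16, 3
state_weight = 0.5 ** (0.5 / n_predict)          # same weighting scheme as above
emb_weight = math.sqrt(1 - state_weight**2)

emb = nn.ModuleList([nn.Embedding(vocab, inner_dim) for _ in range(n_predict)])
proj = nn.ModuleList(
    [
        nn.Linear(hidden if i == 0 else inner_dim, inner_dim, bias=False)
        for i in range(n_predict)
    ]
)
head = nn.ModuleList([nn.Linear(inner_dim, vocab, bias=False) for _ in range(n_predict)])
ln = nn.ModuleList([nn.LayerNorm(inner_dim) for _ in range(n_predict)])

state = torch.randn(1, 1, hidden)    # last hidden state of the base model
ind = torch.randint(vocab, (1, 1))   # last sampled token id
for i in range(n_predict):
    z = emb[i](ind) * (emb_weight * math.sqrt(inner_dim / 2))
    state = proj[i](state) * state_weight + z
    state = F.gelu(ln[i](state))
    probs = F.log_softmax(head[i](state), dim=-1)
    ind = probs.argmax(dim=-1)       # greedy top-1, matching top_k_tokens_per_head = [1] * n
    print(f"head {i} -> candidate token {ind.item()}")
```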

View File

@ -1,34 +1,51 @@
import torch import torch
import json
from typing import Tuple, Optional from typing import Tuple, Optional
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.tensor_parallel import TensorParallelHead from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.mlp import MLPSpeculatorHead
class SpeculativeHead(torch.nn.Module): class SpeculativeHead(torch.nn.Module):
def __init__(self, lm_head, medusa): def __init__(self, lm_head, speculator):
super().__init__() super().__init__()
self.head = lm_head self.head = lm_head
self.medusa = medusa self.speculator = speculator
@staticmethod @staticmethod
def load(config, prefix: str, weights): def load(config, prefix: str, weights):
use_medusa = config.use_medusa speculator = config.speculator
if use_medusa: if speculator:
lm_head = None speculator_path = config.speculator["path"]
speculator_config = str(speculator_path / "config.json")
with open(speculator_config, "r") as f:
speculator_config = json.load(f)
config.speculator_config = speculator_config
try: try:
medusa = MedusaHeadV1.load(config, prefix, weights) architecture = speculator_config["architectures"][0]
if architecture == "MLPSpeculatorPreTrainedModel":
speculator = MLPSpeculatorHead.load(config, prefix, weights)
else:
speculator = None
except KeyError:
try:
speculator = MedusaHeadV1.load(config, prefix, weights)
except: except:
medusa = MedusaHeadV2(config, prefix, weights) speculator = MedusaHeadV2(config, prefix, weights)
lm_head = None
else: else:
lm_head = TensorParallelHead.load(config, prefix, weights) lm_head = TensorParallelHead.load(config, prefix, weights)
medusa = None speculator = None
return SpeculativeHead(lm_head, medusa) return SpeculativeHead(lm_head, speculator)
def forward( def forward(
self, input: torch.Tensor self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if self.medusa is not None: if self.speculator is not None:
return self.medusa(input) return self.speculator(input)
assert self.head is not None assert self.head is not None
logits = self.head(input) logits = self.head(input)
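
The dispatch above keys off the speculator repo's `config.json`. A rough restatement of that decision, with a hypothetical config inlined (the `base_model_name_or_path` value is made up; the keys are the ones this PR reads):

```python
# Hypothetical contents of an MLPSpeculator repo's config.json.
speculator_config = {
    "architectures": ["MLPSpeculatorPreTrainedModel"],
    "base_model_name_or_path": "some-org/some-base-model",
    "n_predict": 3,
    "inner_dim": 4096,
}

try:
    architecture = speculator_config["architectures"][0]
    kind = "mlp_speculator" if architecture == "MLPSpeculatorPreTrainedModel" else None
except KeyError:
    # Medusa configs carry no "architectures" entry, hence the KeyError fallback above.
    kind = "medusa"
print(kind)  # -> mlp_speculator
```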

View File

@@ -1,9 +1,10 @@
 import torch
+import os

 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
@@ -135,8 +136,9 @@ def get_model(
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
+    model_type = config_dict.get("model_type", None)

-    use_medusa = None
+    speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
         medusa_revision = revision
@@ -156,6 +158,8 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        # Reload model type from parent.
+        model_type = config_dict.get("model_type", None)
         is_local = Path(medusa_model_id).exists()
         if not is_local:
             medusa_config = hf_hub_download(
@@ -166,11 +170,70 @@ def get_model(
                 revision=medusa_revision,
                 filename="medusa_lm_head.safetensors",
             )
-            use_medusa = Path(medusa_config).parent
+            speculator = {
+                "path": Path(medusa_config).parent,
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
         else:
-            use_medusa = Path(medusa_model_id)
+            speculator = {
+                "path": Path(medusa_model_id),
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }

         method = "medusa"
+    elif model_type == "mlp_speculator":
+        mlp_model_id = model_id
+        mlp_revision = revision
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        speculate_mlp = config_dict["n_predict"]
+        if speculate is not None:
+            if speculate > speculate_mlp:
+                raise RuntimeError(
+                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_mlp)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        # Reload model type from parent.
+        model_type = config_dict.get("model_type", None)
+        is_local = Path(mlp_model_id).exists()
+        extension = ".safetensors"
+        if not is_local:
+            mlp_speculator_config = hf_hub_download(
+                mlp_model_id, revision=mlp_revision, filename="config.json"
+            )
+            api = HfApi()
+            info = api.model_info(mlp_model_id, revision=mlp_revision)
+            filenames = [
+                s.rfilename
+                for s in info.siblings
+                if s.rfilename.endswith(extension)
+                and len(s.rfilename.split("/")) == 1
+                and "arguments" not in s.rfilename
+                and "args" not in s.rfilename
+                and "training" not in s.rfilename
+            ]
+            for filename in filenames:
+                hf_hub_download(
+                    mlp_model_id,
+                    revision=mlp_revision,
+                    filename=filename,
+                )
+            speculator = {
+                "path": Path(mlp_speculator_config).parent,
+                "model_paths": filenames,
+            }
+        else:
+            speculator = Path(mlp_model_id)
+            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
+            speculator = {"path": speculator, "model_paths": filenames}
+        method = "mlp_speculator"
     else:
         method = "n-gram"
@@ -178,7 +241,6 @@ def get_model(
     if speculate > 0:
         logger.info(f"Using speculation {method} with {speculate} input ids.")

-    model_type = config_dict.get("model_type", None)
     if model_type is None:
         # TODO: fix how we determine model type for Mamba
         if "ssm_cfg" in config_dict:
@@ -202,7 +264,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -212,7 +274,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -227,7 +289,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -240,7 +302,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -250,7 +312,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -259,7 +321,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -270,7 +332,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -279,7 +341,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -288,7 +350,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -299,7 +361,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -308,7 +370,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -323,7 +385,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -334,7 +396,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -345,7 +407,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -355,7 +417,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -366,7 +428,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -377,7 +439,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -388,7 +450,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -399,7 +461,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -410,7 +472,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -424,7 +486,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -435,7 +497,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -444,7 +506,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -458,7 +520,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -469,7 +531,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -483,7 +545,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -494,7 +556,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -520,7 +582,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -544,7 +606,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -554,7 +616,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -564,7 +626,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -574,7 +636,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -586,7 +648,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -599,7 +661,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -623,7 +685,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -632,7 +694,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -644,7 +706,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -653,7 +715,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
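
The call-site hunks above and the per-model files below are mechanical: every model constructor now accepts `speculator` instead of `use_medusa` and stores it on the config. A hypothetical class showing that pattern (the class and config object here are stand-ins, not TGI code):

```python
from typing import Optional
import torch

class DummySharded:
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        class _Config:
            pass

        config = _Config()
        config.quantize = quantize
        config.speculator = speculator  # read later by SpeculativeHead.load
        self.config = config

model = DummySharded(
    "some/model",
    speculator={"path": "/data/spec", "model_paths": ["model.safetensors"]},
)
print(model.config.speculator)
```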

View File

@@ -42,7 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -71,7 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@@ -482,12 +482,12 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        if use_medusa:
-            raise RuntimeError("Medusa decoding is not enabled for AutoModel")
+        if speculator:
+            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

         if torch.cuda.is_available():
             device = torch.device("cuda")

View File

@@ -683,9 +683,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-        config.vision_config.use_medusa = config.use_medusa
+        config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
-        config.text_config.use_medusa = config.use_medusa
+        config.text_config.speculator = config.speculator

         vision_config = config.vision_config
         self.text_model = load_text_model(

View File

@@ -135,7 +135,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         self.vocab_size = config.text_config.vocab_size
         self.config = config
         config.text_config.quantize = config.quantize
-        config.text_config.use_medusa = config.use_medusa
+        config.text_config.speculator = config.speculator
         self.language_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,

View File

@@ -1101,6 +1101,8 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0

+            logger.info(f"Accepted ids {n_accepted_ids}")
+
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token

View File

@@ -24,7 +24,7 @@ class FlashCohere(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -49,7 +49,7 @@ class FlashCohere(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)

View File

@@ -26,7 +26,7 @@ class FlashDbrx(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -74,7 +74,7 @@ class FlashDbrx(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)

View File

@@ -25,7 +25,7 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -50,7 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)

View File

@@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)

View File

@@ -313,7 +313,7 @@ class BaseFlashMistral(FlashCausalLM):
         config_cls=AutoConfig,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
@@ -340,7 +340,7 @@ class BaseFlashMistral(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         # Set context windows
         if getattr(config, "sliding_window", None) is not None:
@@ -567,7 +567,7 @@ class FlashMistral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -577,7 +577,7 @@ class FlashMistral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

View File

@@ -15,7 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -25,7 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

View File

@@ -25,7 +25,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,7 +51,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
        )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@@ -25,7 +25,7 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -48,7 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
@@ -58,7 +58,7 @@ class FlashPhi(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         model = FlashPhiForCausalLM(config, weights)
-        if use_medusa:
+        if speculator:
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
@@ -66,19 +66,19 @@ class FlashPhi(FlashCausalLM):
             from pathlib import Path

             is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
+                Path(speculator).exists() and Path(speculator).is_dir()
             ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

             if not is_local_model:
                 medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
+                    speculator, revision=revision, filename="config.json"
                 )
                 medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                    speculator, revision=revision, filename="medusa_lm_head.pt"
                 )
             else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+                medusa_config = str(Path(speculator) / "config.json")
+                medusa_head = str(Path(speculator) / "medusa_lm_head.pt")

             with open(medusa_config, "r") as f:
                 config = json.load(f)

View File

@@ -30,7 +30,7 @@ class FlashQwen2(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -53,7 +53,7 @@ class FlashQwen2(BaseFlashMistral):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         # Set context windows
         if config.sliding_window is not None:

View File

@@ -26,7 +26,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -66,7 +66,7 @@ class FlashRWSharded(FlashCausalLM):
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)

View File

@@ -29,7 +29,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -57,7 +57,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             trust_remote_code=True,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         config.transpose = config.architectures[0].startswith("GPT2")

         torch.distributed.barrier(group=self.process_group)

View File

@@ -29,7 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,7 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         # Set context windows
         if config.sliding_window is not None:

View File

@@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -195,7 +195,7 @@ class GalacticaSharded(CausalLM):
         )
         config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@@ -24,7 +24,7 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,7 +51,7 @@ class GPTNeoxSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@@ -31,7 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         config.vision_config.quantize = quantize

         tokenizer = LlamaTokenizerFast.from_pretrained(

View File

@@ -18,7 +18,7 @@ class Idefics2(VlmCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -35,7 +35,7 @@ class Idefics2(VlmCausalLM):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

View File

@@ -18,7 +18,7 @@ class LlavaNext(VlmCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -30,7 +30,7 @@ class LlavaNext(VlmCausalLM):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

View File

@@ -408,7 +408,7 @@ class Mamba(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -445,7 +445,7 @@ class Mamba(Model):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

View File

@@ -43,7 +43,7 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -76,7 +76,7 @@ class MPTSharded(CausalLM):
             config = json.load(f)
         config = PretrainedConfig(**config)
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
View File

@@ -22,7 +22,7 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -48,7 +48,7 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         tokenizer.pad_token_id = config.pad_token_id

         torch.distributed.barrier(group=self.process_group)

View File

@@ -22,7 +22,7 @@ class Phi(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -53,7 +53,7 @@ class Phi(CausalLM):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

View File

@@ -12,11 +12,11 @@ class RW(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        if use_medusa:
+        if speculator:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")

         if torch.cuda.is_available():

View File

@@ -19,7 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

View File

@@ -532,12 +532,12 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        if use_medusa:
-            raise RuntimeError("Medusa decoding is not enabled for AutoModel")
+        if speculator:
+            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

         if torch.cuda.is_available():
             device = torch.device("cuda")

View File

@@ -25,7 +25,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -43,7 +43,7 @@ class T5Sharded(Seq2SeqLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         tokenizer = AutoTokenizer.from_pretrained(
             model_id,