MLPSpeculator. (#1865)
# What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Joshua Rosenkranz <joshua.rosenkranz@gmail.com>
This commit is contained in:
parent
3136f27f36
commit
e3d765645a
|
@ -3,11 +3,11 @@ from text_generation_server.layers.tensor_parallel import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.speculative import SpeculativeHead
|
|
||||||
from text_generation_server.layers.linear import (
|
from text_generation_server.layers.linear import (
|
||||||
get_linear,
|
get_linear,
|
||||||
FastLinear,
|
FastLinear,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.layers.speculative import SpeculativeHead
|
||||||
|
|
||||||
# Just to add the `load` methods.
|
# Just to add the `load` methods.
|
||||||
from text_generation_server.layers.layernorm import load_layer_norm
|
from text_generation_server.layers.layernorm import load_layer_norm
|
||||||
|
|
|
@ -69,21 +69,24 @@ class MedusaHeadV1(nn.Module):
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
import json
|
import json
|
||||||
|
|
||||||
use_medusa = config.use_medusa
|
speculator = config.speculator
|
||||||
|
|
||||||
medusa_config = str(Path(use_medusa) / "config.json")
|
path = speculator["path"]
|
||||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
medusa_config = str(Path(path) / "config.json")
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
for fname in speculator["model_paths"]:
|
||||||
medusa_config = json.load(f)
|
filename = str(Path(path) / fname)
|
||||||
routing = weights.routing
|
|
||||||
with safe_open(filename, framework="pytorch") as f:
|
with open(medusa_config, "r") as f:
|
||||||
for k in f.keys():
|
medusa_config = json.load(f)
|
||||||
if k in routing and routing[k] != filename:
|
routing = weights.routing
|
||||||
raise RuntimeError(
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
for k in f.keys():
|
||||||
)
|
if k in routing and routing[k] != filename:
|
||||||
routing[k] = filename
|
raise RuntimeError(
|
||||||
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
|
)
|
||||||
|
routing[k] = filename
|
||||||
|
|
||||||
medusa = MedusaModel(config, medusa_config, weights)
|
medusa = MedusaModel(config, medusa_config, weights)
|
||||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
|
@ -108,10 +111,10 @@ class MedusaHeadV2(nn.Module):
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
import json
|
import json
|
||||||
|
|
||||||
use_medusa = config.use_medusa
|
speculator = config.speculator
|
||||||
|
|
||||||
medusa_config = str(Path(use_medusa) / "config.json")
|
medusa_config = str(Path(speculator) / "config.json")
|
||||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
filename = str(Path(speculator) / "medusa_lm_head.safetensors")
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
medusa_config = json.load(f)
|
medusa_config = json.load(f)
|
||||||
|
|
|
@ -0,0 +1,176 @@
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from text_generation_server.layers import TensorParallelEmbedding, FastLinear
|
||||||
|
from text_generation_server.layers.tensor_parallel import TensorParallelHead
|
||||||
|
from text_generation_server.utils.speculate import get_speculate
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorLayerNorm(nn.Module):
    """
    Root-mean-square normalization over the last axis, followed by a learned
    elementwise scale and shift.

    The input is divided by ``sqrt(mean(x ** 2, dim=-1) + eps)``, multiplied
    by ``weight`` and then offset by ``bias``.

    Args
    ----
    prefix : str
        Name prefix of the tensors to load; ``{prefix}.weight`` and
        ``{prefix}.bias`` are fetched from ``weights``.
    config :
        Not read by this layer; accepted so the constructor has the same
        shape as the other layers built from ``(prefix, config, weights)``.
    weights :
        Object exposing ``get_tensor(name)`` that returns the stored tensor.
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme (i.e. fp16 requires
        eps >= 6e-8).
    """

    def __init__(
        self,
        prefix,
        config,
        weights,
        eps=1e-06,
    ):
        super().__init__()
        # Stored as plain tensors (not nn.Parameter): they are used as-is in
        # forward and are not registered in the module's parameter list.
        self.weight = weights.get_tensor(f"{prefix}.weight")
        self.bias = weights.get_tensor(f"{prefix}.bias")
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its last axis, then apply scale and shift."""
        # Reciprocal of the root-mean-square of the last axis.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back to the input dtype before the affine transform.
        normed = (x * inv_rms).type_as(x)
        return self.weight * normed + self.bias
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorModel(torch.nn.Module):
    """
    Stack of ``n_predict`` speculative heads.

    Head ``i`` embeds the previously chosen token, mixes it with a projection
    of the running state, normalizes and activates the result, then maps it
    to a distribution over the vocabulary. The top token of head ``i`` is
    the token fed to head ``i + 1``.

    Args
    ----
    config :
        Model configuration; ``hidden_size``, ``vocab_size`` and
        ``speculator_config["inner_dim"]`` are read.
    prefix : str
        Name prefix under which the ``emb.{i}``, ``proj.{i}``, ``head.{i}``
        and ``ln.{i}`` tensors are looked up.
    weights :
        Weight store forwarded to each sub-layer loader.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.config = config
        # One head per speculated token.
        self.n_predict = get_speculate()
        self.hidden_size = config.hidden_size
        self.emb = nn.ModuleList(
            [
                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
                for i in range(self.n_predict)
            ]
        )
        # Wrapped in a ModuleList like emb/head/ln so that the projection
        # layers are registered as submodules (a plain Python list is
        # invisible to .modules(), .to(), state_dict(), ...).
        self.proj = nn.ModuleList(
            [
                FastLinear.load(
                    config,
                    prefix=f"{prefix}.proj.{i}",
                    weights=weights,
                    bias=False,
                )
                for i in range(self.n_predict)
            ]
        )
        self.head = nn.ModuleList(
            [
                FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
                for i in range(self.n_predict)
            ]
        )
        self.ln = nn.ModuleList(
            [
                MLPSpeculatorLayerNorm(
                    prefix=f"{prefix}.ln.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(self.n_predict)
            ]
        )

        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
        self.state_weight = 0.5 ** (0.5 / self.n_predict)
        self.emb_weight = math.sqrt(1 - self.state_weight**2)
        self.activation = nn.GELU()
        # TODO
        self.vsize = config.vocab_size
        self.inner_dim = config.speculator_config["inner_dim"]
        # Only the single best token of each head is kept.
        self.top_k_tokens_per_head = [1] * self.n_predict

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        """
        Run every head in sequence and collect its vocabulary distribution.

        Args
        ----
        hidden_states : torch.Tensor
            Initial state; its first dimension is used as the batch size.
            NOTE(review): presumably (batch, hidden) - confirm with caller.
        input_ids : torch.Tensor
            Token ids embedded by the first head.
            NOTE(review): presumably (batch,) - confirm with caller.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, n_predict, vocab_size)`` holding, for
            each head, the softmax probabilities it produced (the values
            are probabilities, not raw logits, despite the variable name).
        """
        top_k_tokens_per_head = self.top_k_tokens_per_head

        state = hidden_states
        b = state.size(0)
        # Leading singleton axis added in front of the batch of token ids.
        ind = input_ids.unsqueeze(0)
        # Allocated as (b, n_predict, vsize); uses the default dtype.
        all_probs = torch.empty(b, self.n_predict, self.vsize, device=state.device)
        assert (
            len(top_k_tokens_per_head) == self.n_predict
        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
        for i in range(self.n_predict):
            # Project and predict
            z = self.emb[i](ind)
            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))
            state = self.proj[i](state) * self.state_weight + z
            state = self.activation(self.ln[i](state))
            probs = F.log_softmax(self.head[i](state), dim=-1)
            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)

            # Update distribution set with new probabilities
            all_probs[:, i] = probs.exp()

            # Update state and ind for new predictions: add a candidate
            # axis of size top-k, then fold it back so the batch axis keeps
            # size b.
            state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)
            state = state.reshape(-1, b, state.size(3))
            ind = preds.view(-1, b)

        speculative_logits = all_probs
        return speculative_logits
|
||||||
|
|
||||||
|
|
||||||
|
class MLPSpeculatorHead(nn.Module):
    """
    Language-model head paired with an MLP speculator.

    ``forward`` returns the regular logits together with the speculator
    output computed from the same hidden states and the arg-max token of
    those logits.
    """

    # Inputs whose first dimension exceeds this value skip speculation.
    _MAX_SPECULATIVE_INPUTS = 128

    def __init__(self, lm_head, mlp_speculator):
        super().__init__()
        self.lm_head = lm_head
        self.mlp_speculator = mlp_speculator

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute logits and, when the input is small enough, speculative output.

        Returns
        -------
        Tuple[torch.Tensor, Optional[torch.Tensor]]
            ``(logits, speculative_logits)``; the second element is ``None``
            when ``input.shape[0]`` is greater than 128.
        """
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > self._MAX_SPECULATIVE_INPUTS:
            return logits, None

        # The greedy token of the base head seeds the speculator.
        input_ids = logits.argmax(dim=-1)
        speculative_logits = self.mlp_speculator(input, input_ids)
        return logits, speculative_logits

    @staticmethod
    def load(config, prefix: str, weights):
        """
        Build the head from ``config.speculator`` and the weight store.

        Every tensor name found in the speculator files is added to
        ``weights.routing`` so that later lookups resolve to the right file.

        Raises
        ------
        RuntimeError
            If a tensor name is already routed to a different file.
        """
        from pathlib import Path
        from safetensors import safe_open

        speculator_path = config.speculator["path"]
        # Same routing table for every file, so look it up once.
        routing = weights.routing

        for fname in config.speculator["model_paths"]:
            filename = str(Path(speculator_path) / fname)
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing and routing[k] != filename:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename

        mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MLPSpeculatorHead(lm_head, mlp_speculator)
|
|
@ -1,34 +1,51 @@
|
||||||
import torch
|
import torch
|
||||||
|
import json
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
|
|
||||||
from text_generation_server.layers.tensor_parallel import TensorParallelHead
|
from text_generation_server.layers.tensor_parallel import TensorParallelHead
|
||||||
|
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
|
||||||
|
from text_generation_server.layers.mlp import MLPSpeculatorHead
|
||||||
|
|
||||||
|
|
||||||
class SpeculativeHead(torch.nn.Module):
|
class SpeculativeHead(torch.nn.Module):
|
||||||
def __init__(self, lm_head, medusa):
|
def __init__(self, lm_head, speculator):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head = lm_head
|
self.head = lm_head
|
||||||
self.medusa = medusa
|
self.speculator = speculator
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(config, prefix: str, weights):
|
def load(config, prefix: str, weights):
|
||||||
use_medusa = config.use_medusa
|
speculator = config.speculator
|
||||||
if use_medusa:
|
if speculator:
|
||||||
lm_head = None
|
speculator_path = config.speculator["path"]
|
||||||
|
speculator_config = str(speculator_path / "config.json")
|
||||||
|
|
||||||
|
with open(speculator_config, "r") as f:
|
||||||
|
speculator_config = json.load(f)
|
||||||
|
|
||||||
|
config.speculator_config = speculator_config
|
||||||
try:
|
try:
|
||||||
medusa = MedusaHeadV1.load(config, prefix, weights)
|
architecture = speculator_config["architectures"][0]
|
||||||
except:
|
|
||||||
medusa = MedusaHeadV2(config, prefix, weights)
|
if architecture == "MLPSpeculatorPreTrainedModel":
|
||||||
|
speculator = MLPSpeculatorHead.load(config, prefix, weights)
|
||||||
|
else:
|
||||||
|
speculator = None
|
||||||
|
except KeyError:
|
||||||
|
try:
|
||||||
|
speculator = MedusaHeadV1.load(config, prefix, weights)
|
||||||
|
except:
|
||||||
|
speculator = MedusaHeadV2(config, prefix, weights)
|
||||||
|
lm_head = None
|
||||||
else:
|
else:
|
||||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||||
medusa = None
|
speculator = None
|
||||||
return SpeculativeHead(lm_head, medusa)
|
return SpeculativeHead(lm_head, speculator)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input: torch.Tensor
|
self, input: torch.Tensor
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
if self.medusa is not None:
|
if self.speculator is not None:
|
||||||
return self.medusa(input)
|
return self.speculator(input)
|
||||||
|
|
||||||
assert self.head is not None
|
assert self.head is not None
|
||||||
logits = self.head(input)
|
logits = self.head(input)
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.models.auto import modeling_auto
|
from transformers.models.auto import modeling_auto
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download, HfApi
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -135,8 +136,9 @@ def get_model(
|
||||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
|
model_type = config_dict.get("model_type", None)
|
||||||
|
|
||||||
use_medusa = None
|
speculator = None
|
||||||
if "medusa_num_heads" in config_dict:
|
if "medusa_num_heads" in config_dict:
|
||||||
medusa_model_id = model_id
|
medusa_model_id = model_id
|
||||||
medusa_revision = revision
|
medusa_revision = revision
|
||||||
|
@ -156,6 +158,8 @@ def get_model(
|
||||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
|
# Reload model type from parent.
|
||||||
|
model_type = config_dict.get("model_type", None)
|
||||||
is_local = Path(medusa_model_id).exists()
|
is_local = Path(medusa_model_id).exists()
|
||||||
if not is_local:
|
if not is_local:
|
||||||
medusa_config = hf_hub_download(
|
medusa_config = hf_hub_download(
|
||||||
|
@ -166,11 +170,70 @@ def get_model(
|
||||||
revision=medusa_revision,
|
revision=medusa_revision,
|
||||||
filename="medusa_lm_head.safetensors",
|
filename="medusa_lm_head.safetensors",
|
||||||
)
|
)
|
||||||
use_medusa = Path(medusa_config).parent
|
speculator = {
|
||||||
|
"path": Path(medusa_config).parent,
|
||||||
|
"model_paths": ["medusa_lm_head.safetensors"],
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
use_medusa = Path(medusa_model_id)
|
speculator = {
|
||||||
|
"path": Path(medusa_model_id),
|
||||||
|
"model_paths": ["medusa_lm_head.safetensors"],
|
||||||
|
}
|
||||||
|
|
||||||
method = "medusa"
|
method = "medusa"
|
||||||
|
elif model_type == "mlp_speculator":
|
||||||
|
mlp_model_id = model_id
|
||||||
|
mlp_revision = revision
|
||||||
|
model_id = config_dict["base_model_name_or_path"]
|
||||||
|
revision = "main"
|
||||||
|
speculate_mlp = config_dict["n_predict"]
|
||||||
|
if speculate is not None:
|
||||||
|
if speculate > speculate_mlp:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
set_speculate(speculate)
|
||||||
|
else:
|
||||||
|
set_speculate(speculate_mlp)
|
||||||
|
|
||||||
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
# Reload model type from parent.
|
||||||
|
model_type = config_dict.get("model_type", None)
|
||||||
|
is_local = Path(mlp_model_id).exists()
|
||||||
|
extension = ".safetensors"
|
||||||
|
if not is_local:
|
||||||
|
mlp_speculator_config = hf_hub_download(
|
||||||
|
mlp_model_id, revision=mlp_revision, filename="config.json"
|
||||||
|
)
|
||||||
|
api = HfApi()
|
||||||
|
info = api.model_info(mlp_model_id, revision=mlp_revision)
|
||||||
|
filenames = [
|
||||||
|
s.rfilename
|
||||||
|
for s in info.siblings
|
||||||
|
if s.rfilename.endswith(extension)
|
||||||
|
and len(s.rfilename.split("/")) == 1
|
||||||
|
and "arguments" not in s.rfilename
|
||||||
|
and "args" not in s.rfilename
|
||||||
|
and "training" not in s.rfilename
|
||||||
|
]
|
||||||
|
for filename in filenames:
|
||||||
|
hf_hub_download(
|
||||||
|
mlp_model_id,
|
||||||
|
revision=mlp_revision,
|
||||||
|
filename=filename,
|
||||||
|
)
|
||||||
|
speculator = {
|
||||||
|
"path": Path(mlp_speculator_config).parent,
|
||||||
|
"model_paths": filenames,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
speculator = Path(mlp_model_id)
|
||||||
|
filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
|
||||||
|
speculator = {"path": speculator, "model_paths": filenames}
|
||||||
|
method = "mlp_speculator"
|
||||||
else:
|
else:
|
||||||
method = "n-gram"
|
method = "n-gram"
|
||||||
|
|
||||||
|
@ -178,7 +241,6 @@ def get_model(
|
||||||
if speculate > 0:
|
if speculate > 0:
|
||||||
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
||||||
|
|
||||||
model_type = config_dict.get("model_type", None)
|
|
||||||
if model_type is None:
|
if model_type is None:
|
||||||
# TODO: fix how we determine model type for Mamba
|
# TODO: fix how we determine model type for Mamba
|
||||||
if "ssm_cfg" in config_dict:
|
if "ssm_cfg" in config_dict:
|
||||||
|
@ -202,7 +264,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -212,7 +274,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -227,7 +289,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -240,7 +302,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -250,7 +312,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -259,7 +321,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -270,7 +332,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -279,7 +341,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -288,7 +350,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -299,7 +361,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -308,7 +370,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -323,7 +385,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -334,7 +396,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -345,7 +407,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -355,7 +417,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -366,7 +428,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -377,7 +439,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -388,7 +450,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -399,7 +461,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -410,7 +472,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -424,7 +486,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -435,7 +497,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -444,7 +506,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -458,7 +520,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -469,7 +531,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -483,7 +545,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -494,7 +556,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -520,7 +582,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -544,7 +606,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -554,7 +616,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -564,7 +626,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -574,7 +636,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -586,7 +648,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -599,7 +661,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -623,7 +685,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -632,7 +694,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -644,7 +706,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -653,7 +715,7 @@ def get_model(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
|
@ -42,7 +42,7 @@ class BLOOMSharded(CausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -71,7 +71,7 @@ class BLOOMSharded(CausalLM):
|
||||||
)
|
)
|
||||||
config.pad_token_id = 3
|
config.pad_token_id = 3
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
|
|
@ -482,12 +482,12 @@ class CausalLM(Model):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
if use_medusa:
|
if speculator:
|
||||||
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
|
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
|
|
@ -683,9 +683,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config.vision_config.quantize = config.quantize
|
config.vision_config.quantize = config.quantize
|
||||||
config.vision_config.use_medusa = config.use_medusa
|
config.vision_config.speculator = config.speculator
|
||||||
config.text_config.quantize = config.quantize
|
config.text_config.quantize = config.quantize
|
||||||
config.text_config.use_medusa = config.use_medusa
|
config.text_config.speculator = config.speculator
|
||||||
|
|
||||||
vision_config = config.vision_config
|
vision_config = config.vision_config
|
||||||
self.text_model = load_text_model(
|
self.text_model = load_text_model(
|
||||||
|
|
|
@ -135,7 +135,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||||
self.vocab_size = config.text_config.vocab_size
|
self.vocab_size = config.text_config.vocab_size
|
||||||
self.config = config
|
self.config = config
|
||||||
config.text_config.quantize = config.quantize
|
config.text_config.quantize = config.quantize
|
||||||
config.text_config.use_medusa = config.use_medusa
|
config.text_config.speculator = config.speculator
|
||||||
self.language_model = load_text_model(
|
self.language_model = load_text_model(
|
||||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||||
config=config.text_config,
|
config=config.text_config,
|
||||||
|
|
|
@ -1101,6 +1101,8 @@ class FlashCausalLM(Model):
|
||||||
next_token_texts = []
|
next_token_texts = []
|
||||||
left = 0
|
left = 0
|
||||||
|
|
||||||
|
logger.info(f"Accepted ids {n_accepted_ids}")
|
||||||
|
|
||||||
current_stopped = False
|
current_stopped = False
|
||||||
for j in range(index, index + n_accepted_ids):
|
for j in range(index, index + n_accepted_ids):
|
||||||
# Generated token
|
# Generated token
|
||||||
|
|
|
@ -24,7 +24,7 @@ class FlashCohere(FlashCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -49,7 +49,7 @@ class FlashCohere(FlashCausalLM):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ class FlashDbrx(FlashCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -74,7 +74,7 @@ class FlashDbrx(FlashCausalLM):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ class FlashGemma(FlashCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -50,7 +50,7 @@ class FlashGemma(FlashCausalLM):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
|
|
@ -313,7 +313,7 @@ class BaseFlashMistral(FlashCausalLM):
|
||||||
config_cls=AutoConfig,
|
config_cls=AutoConfig,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
tokenizer_class=AutoTokenizer,
|
tokenizer_class=AutoTokenizer,
|
||||||
|
@ -340,7 +340,7 @@ class BaseFlashMistral(FlashCausalLM):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
# Set context windows
|
# Set context windows
|
||||||
if getattr(config, "sliding_window", None) is not None:
|
if getattr(config, "sliding_window", None) is not None:
|
||||||
|
@ -567,7 +567,7 @@ class FlashMistral(BaseFlashMistral):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -577,7 +577,7 @@ class FlashMistral(BaseFlashMistral):
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,7 +15,7 @@ class FlashMixtral(BaseFlashMistral):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -25,7 +25,7 @@ class FlashMixtral(BaseFlashMistral):
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,7 +25,7 @@ class FlashNeoXSharded(FlashCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -51,7 +51,7 @@ class FlashNeoXSharded(FlashCausalLM):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
|
|
@ -25,7 +25,7 @@ class FlashPhi(FlashCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -48,7 +48,7 @@ class FlashPhi(FlashCausalLM):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ class FlashPhi(FlashCausalLM):
|
||||||
weights._set_gptq_params(model_id, revision)
|
weights._set_gptq_params(model_id, revision)
|
||||||
|
|
||||||
model = FlashPhiForCausalLM(config, weights)
|
model = FlashPhiForCausalLM(config, weights)
|
||||||
if use_medusa:
|
if speculator:
|
||||||
from text_generation_server.utils.medusa import MedusaModel
|
from text_generation_server.utils.medusa import MedusaModel
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import json
|
import json
|
||||||
|
@ -66,19 +66,19 @@ class FlashPhi(FlashCausalLM):
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
is_local_model = (
|
is_local_model = (
|
||||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
Path(speculator).exists() and Path(speculator).is_dir()
|
||||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||||
|
|
||||||
if not is_local_model:
|
if not is_local_model:
|
||||||
medusa_config = hf_hub_download(
|
medusa_config = hf_hub_download(
|
||||||
use_medusa, revision=revision, filename="config.json"
|
speculator, revision=revision, filename="config.json"
|
||||||
)
|
)
|
||||||
medusa_head = hf_hub_download(
|
medusa_head = hf_hub_download(
|
||||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
speculator, revision=revision, filename="medusa_lm_head.pt"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
medusa_config = str(Path(use_medusa) / "config.json")
|
medusa_config = str(Path(speculator) / "config.json")
|
||||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
|
@ -30,7 +30,7 @@ class FlashQwen2(BaseFlashMistral):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -53,7 +53,7 @@ class FlashQwen2(BaseFlashMistral):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
# Set context windows
|
# Set context windows
|
||||||
if config.sliding_window is not None:
|
if config.sliding_window is not None:
|
||||||
|
|
|
@ -26,7 +26,7 @@ class FlashRWSharded(FlashCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -66,7 +66,7 @@ class FlashRWSharded(FlashCausalLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
weights._set_gptq_params(model_id, revision)
|
weights._set_gptq_params(model_id, revision)
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -57,7 +57,7 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
config.transpose = config.architectures[0].startswith("GPT2")
|
config.transpose = config.architectures[0].startswith("GPT2")
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
|
@ -29,7 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -52,7 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
# Set context windows
|
# Set context windows
|
||||||
if config.sliding_window is not None:
|
if config.sliding_window is not None:
|
||||||
|
|
|
@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -195,7 +195,7 @@ class GalacticaSharded(CausalLM):
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
tokenizer.pad_token_id = config.pad_token_id
|
tokenizer.pad_token_id = config.pad_token_id
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
|
|
@ -24,7 +24,7 @@ class GPTNeoxSharded(CausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -51,7 +51,7 @@ class GPTNeoxSharded(CausalLM):
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
|
|
@ -31,7 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -52,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
config.vision_config.quantize = quantize
|
config.vision_config.quantize = quantize
|
||||||
|
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||||
|
|
|
@ -18,7 +18,7 @@ class Idefics2(VlmCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -35,7 +35,7 @@ class Idefics2(VlmCausalLM):
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ class LlavaNext(VlmCausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -30,7 +30,7 @@ class LlavaNext(VlmCausalLM):
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
|
@ -408,7 +408,7 @@ class Mamba(Model):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -445,7 +445,7 @@ class Mamba(Model):
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
|
|
@ -43,7 +43,7 @@ class MPTSharded(CausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -76,7 +76,7 @@ class MPTSharded(CausalLM):
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
config = PretrainedConfig(**config)
|
config = PretrainedConfig(**config)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ class OPTSharded(CausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -48,7 +48,7 @@ class OPTSharded(CausalLM):
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
tokenizer.pad_token_id = config.pad_token_id
|
tokenizer.pad_token_id = config.pad_token_id
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
|
@ -22,7 +22,7 @@ class Phi(CausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -53,7 +53,7 @@ class Phi(CausalLM):
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
|
|
@ -12,11 +12,11 @@ class RW(CausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
if use_medusa:
|
if speculator:
|
||||||
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
|
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
|
@ -19,7 +19,7 @@ class SantaCoder(CausalLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
|
|
@ -532,12 +532,12 @@ class Seq2SeqLM(Model):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
if use_medusa:
|
if speculator:
|
||||||
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
|
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
|
|
@ -25,7 +25,7 @@ class T5Sharded(Seq2SeqLM):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -43,7 +43,7 @@ class T5Sharded(Seq2SeqLM):
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
|
Loading…
Reference in New Issue