2024-05-13 04:44:30 -06:00
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from typing import Tuple, Optional
|
|
|
|
from text_generation_server.utils.speculate import get_speculate
|
|
|
|
from text_generation_server.layers.linear import FastLinear
|
|
|
|
from text_generation_server.layers.tensor_parallel import (
|
|
|
|
TensorParallelHead,
|
|
|
|
TensorParallelColumnLinear,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class ResBlock(torch.nn.Module):
    """Residual block computing ``x + SiLU(linear(x))``.

    The linear layer (with bias) is loaded from ``{prefix}.linear``.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        linear_prefix = f"{prefix}.linear"
        self.linear = FastLinear.load(
            config,
            prefix=linear_prefix,
            weights=weights,
            bias=True,
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        """Add the SiLU-activated projection of ``x`` back onto ``x``."""
        projected = self.linear(x)
        activated = self.act(projected)
        return x + activated
|
|
|
|
|
|
|
|
|
|
|
|
class MedusaModel(torch.nn.Module):
    """A list of ``MedusaHead`` modules whose outputs are stacked together.

    The number of heads is whatever ``get_speculate()`` returns at
    construction time; head ``i`` reads its weights under prefix ``"i"``.
    """

    def __init__(self, config, medusa_config, weights):
        super().__init__()
        head_count = get_speculate()
        built_heads = [
            MedusaHead(config, medusa_config, prefix=f"{idx}", weights=weights)
            for idx in range(head_count)
        ]
        self.heads = torch.nn.ModuleList(built_heads)

    def forward(self, x):
        """Run every head on ``x`` and stack the results along dim 1."""
        per_head_outputs = []
        for medusa_head in self.heads:
            per_head_outputs.append(medusa_head(x))
        return torch.stack(per_head_outputs, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
class MedusaHead(torch.nn.Module):
    """One Medusa head: a chain of ``ResBlock`` layers plus an output projection.

    Layer ``i`` is loaded from ``{prefix}.{i}``. The bias-free output
    projection is loaded from the next index, ``{prefix}.{num_layers}``.
    ``medusa_config`` must provide ``"medusa_num_layers"``.
    """

    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            ResBlock(config, prefix=f"{prefix}.{layer_idx}", weights=weights)
            for layer_idx in range(medusa_config["medusa_num_layers"])
        )
        # The output projection is stored right after the residual layers.
        out_index = len(self.blocks)
        self.out = FastLinear.load(
            config,
            prefix=f"{prefix}.{out_index}",
            weights=weights,
            bias=False,
        )

    def forward(self, x):
        """Feed ``x`` through every residual layer, then project it."""
        hidden = x
        for res_block in self.blocks:
            hidden = res_block(hidden)
        return self.out(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
class MedusaHeadV1(nn.Module):
    """LM head paired with a separate Medusa module for speculative logits.

    ``forward`` returns ``(logits, speculative_logits)``. The speculative part
    is ``None`` when the input has more than ``SPECULATIVE_MAX_TOKENS`` rows.
    """

    # Inputs with more rows than this skip the Medusa heads entirely.
    SPECULATIVE_MAX_TOKENS = 128

    def __init__(self, lm_head, medusa):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Build a ``MedusaHeadV1`` from the speculator described by ``config``.

        ``config.speculator`` is read as a mapping with:
          * ``"path"``: a directory containing ``config.json``;
          * ``"model_paths"``: safetensors file names relative to that directory.

        Every tensor key found in those files is recorded in
        ``weights.routing`` (key -> file name).

        Raises:
            RuntimeError: if the same tensor key appears in two different files.
        """
        from pathlib import Path
        from safetensors import safe_open
        import json

        speculator = config.speculator
        path = Path(speculator["path"])

        # Parse the Medusa config exactly once. The previous version did this
        # inside the loop below and rebound the path variable to the parsed
        # dict. A second model path then called ``open()`` on a dict
        # (TypeError), and an empty ``model_paths`` left an unparsed string.
        with open(path / "config.json", "r") as config_file:
            medusa_config = json.load(config_file)

        routing = weights.routing
        for fname in speculator["model_paths"]:
            filename = str(path / fname)
            with safe_open(filename, framework="pytorch") as tensor_file:
                for k in tensor_file.keys():
                    if k in routing and routing[k] != filename:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename

        medusa = MedusaModel(config, medusa_config, weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)`` for ``input``.

        ``speculative_logits`` is ``None`` when ``input.shape[0]`` exceeds
        ``SPECULATIVE_MAX_TOKENS``.
        """
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > self.SPECULATIVE_MAX_TOKENS:
            return logits, None

        speculative_logits = self.medusa(input)
        return logits, speculative_logits
|
|
|
|
|
|
|
|
|
|
|
|
class MedusaHeadV2(nn.Module):
    """Fused Medusa head: every head's residual linear runs as one sharded matmul.

    Loads ``config.json`` and ``medusa_lm_head.safetensors`` from the
    speculator directory and registers the file's tensor keys in
    ``weights.routing``. Only ``medusa_num_layers == 1`` is supported.
    ``forward`` returns ``(logits, speculative_logits)``; the speculative part
    is ``None`` when the input has more than ``SPECULATIVE_MAX_TOKENS`` rows.
    """

    # Inputs with more rows than this skip the Medusa heads entirely.
    SPECULATIVE_MAX_TOKENS = 128

    def __init__(self, config, prefix, weights):
        super().__init__()
        from collections.abc import Mapping
        from pathlib import Path
        from safetensors import safe_open
        import json

        speculator = config.speculator
        # ``MedusaHeadV1.load`` reads ``config.speculator`` as a mapping with
        # a ``"path"`` entry. ``Path(mapping)`` would raise TypeError, so
        # accept that form here too, while still allowing a plain path.
        if isinstance(speculator, Mapping):
            speculator_path = Path(speculator["path"])
        else:
            speculator_path = Path(speculator)

        filename = str(speculator_path / "medusa_lm_head.safetensors")

        with open(speculator_path / "config.json", "r") as config_file:
            medusa_config = json.load(config_file)

        routing = weights.routing
        with safe_open(filename, framework="pytorch") as tensor_file:
            for k in tensor_file.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        self.n_medusa_heads = get_speculate()

        # The fused layout below covers exactly one residual linear per head.
        assert medusa_config["medusa_num_layers"] == 1, (
            "MedusaHeadV2 only supports medusa_num_layers == 1, got "
            f"{medusa_config['medusa_num_layers']}"
        )
        self.linear = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.process_group = weights.process_group
        self.world_size = self.process_group.size()
        self.rank = self.process_group.rank()

        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)

    def forward(self, x):
        """Return ``(logits, speculative_logits)`` for the 2-D input ``x``.

        ``x`` is indexed as ``x[:, start:stop]`` below, so it is treated as
        ``(tokens, hidden)``.
        """
        # If we have too many tokens, we skip speculative logits
        if x.shape[0] > self.SPECULATIVE_MAX_TOKENS:
            logits = self.lm_head(x)
            return logits, None

        # This rank's slice of the hidden dimension (ceil-divided across ranks).
        size = x.shape[-1]
        block_size = (size + self.world_size - 1) // self.world_size
        start = self.rank * block_size
        stop = (self.rank + 1) * block_size

        x_block = x[:, start:stop]

        # Compute all medusa heads at once, then reshape so the
        # n_medusa_heads dimension sits just before the hidden slice.
        # NOTE(review): assumes the column-parallel output is laid out as
        # n_medusa_heads contiguous chunks of this rank's slice — confirm
        # against TensorParallelColumnLinear.load_multi.
        medusa_res = self.act(self.linear(x)).reshape(
            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
        )

        # Apply all residual medusa heads
        output = x_block.unsqueeze(-2) + medusa_res

        # Gather medusa heads from all ranks. A single rank already holds the
        # whole hidden dimension, so the collective would be an identity there.
        if self.world_size > 1:
            world_output = [torch.empty_like(output) for _ in range(self.world_size)]
            torch.distributed.all_gather(
                world_output, output, group=self.process_group
            )
            world_output = torch.cat(world_output, dim=-1)
        else:
            world_output = output

        # Stack x and medusa residual x
        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

        # Compute lm head on x + medusa residual x
        logits = self.lm_head(stacked_x)

        # Finally, split logits from speculative logits
        logits, speculative_logits = torch.split(
            logits, [1, self.n_medusa_heads], dim=-2
        )
        # Squeeze added dimension
        logits = logits.squeeze(-2)

        return logits, speculative_logits
|