2023-12-11 04:46:30 -07:00
|
|
|
from dataclasses import dataclass
from typing import Optional

import torch

from text_generation_server.utils.layers import TensorParallelHead, FastLinear
|
|
|
|
|
2023-12-11 06:49:52 -07:00
|
|
|
|
2023-12-11 04:46:30 -07:00
|
|
|
@dataclass
class Output:
    """Result of a Medusa forward pass.

    Bundles the regular next-token logits with the speculative logits
    produced by the extra Medusa heads.
    """

    # Logits from the base LM head; None until a forward pass fills it in.
    logits: Optional[torch.FloatTensor] = None
    # Per-head speculative logits stacked along dim=1; None when
    # speculation is disabled.
    speculative_logits: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
|
|
|
|
|
|
class ResBlock(torch.nn.Module):
    """A single residual block: ``x -> x + SiLU(Linear(x))``."""

    def __init__(self, config, prefix, weights):
        """Load the block's linear projection from the checkpoint.

        The weights live under ``{prefix}.linear`` in the checkpoint;
        ``config`` is forwarded to the loader unchanged.
        """
        super().__init__()
        self.linear = FastLinear.load(
            config, prefix=f"{prefix}.linear", weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        # Residual connection around the activated projection.
        projected = self.linear(x)
        return x + self.act(projected)
|
|
|
|
|
|
|
|
|
|
|
|
class MedusaModel(torch.nn.Module):
    """Wraps a language-model head with a set of Medusa speculation heads.

    ``forward`` returns both the base LM logits and the stacked
    speculative logits from every head.
    """

    def __init__(self, config, weights, lm_head):
        super().__init__()
        # One MedusaHead per configured head; checkpoint prefixes are the
        # plain head indices "0", "1", ...
        num_heads = config["medusa_num_heads"]
        self.heads = torch.nn.ModuleList(
            MedusaHead(config, prefix=f"{i}", weights=weights)
            for i in range(num_heads)
        )
        self.lm_head = lm_head

    def forward(self, x):
        # Base next-token logits from the regular LM head.
        logits = self.lm_head(x)
        # One logit tensor per Medusa head, stacked along a new dim=1.
        per_head = [head(x) for head in self.heads]
        speculative_logits = torch.stack(per_head, dim=1)
        return logits, speculative_logits
|
|
|
|
|
|
|
|
|
|
|
|
class MedusaHead(torch.nn.Module):
    """One Medusa head: a stack of residual blocks followed by a final
    linear projection."""

    def __init__(self, config, prefix, weights):
        super().__init__()
        # Residual blocks live at checkpoint keys `{prefix}.0` ..
        # `{prefix}.{num_layers - 1}`.
        num_layers = config["medusa_num_layers"]
        self.blocks = torch.nn.ModuleList(
            ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
            for i in range(num_layers)
        )
        # The output projection is stored right after the last residual
        # block, hence its index equals the block count.
        self.out = FastLinear.load(
            config,
            prefix=f"{prefix}.{len(self.blocks)}",
            weights=weights,
            bias=False,
        )

    def forward(self, x):
        # Pipe the input through every residual block, then project.
        hidden = x
        for block in self.blocks:
            hidden = block(hidden)
        return self.out(hidden)
|