import torch
import json
from typing import Tuple, Optional

from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.mlp import MLPSpeculatorHead

class SpeculativeHead(torch.nn.Module):
    """Wraps a language-model head with an optional speculative-decoding head."""

    def __init__(self, lm_head, speculator):
        super().__init__()
        self.head = lm_head
        self.speculator = speculator

    @staticmethod
    def load(config, prefix: str, weights):
        speculator = config.speculator
        if speculator:
            # A speculator is configured: read its config.json from disk.
            speculator_path = config.speculator["path"]
            speculator_config = str(speculator_path / "config.json")

            with open(speculator_config, "r") as f:
                speculator_config = json.load(f)

            config.speculator_config = speculator_config
            try:
                architecture = speculator_config["architectures"][0]

                if architecture == "MLPSpeculatorPreTrainedModel":
                    speculator = MLPSpeculatorHead.load(config, prefix, weights)
                else:
                    speculator = None
            except KeyError:
                # No "architectures" entry in the config: fall back to Medusa.
                try:
                    speculator = MedusaHeadV1.load(config, prefix, weights)
                except:
                    speculator = MedusaHeadV2(config, prefix, weights)
            lm_head = None
        else:
            # No speculator: load a plain tensor-parallel LM head.
            lm_head = TensorParallelHead.load(config, prefix, weights)
            speculator = None
        return SpeculativeHead(lm_head, speculator)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # When a speculator is loaded, it wraps the LM head itself and returns
        # both the regular logits and the speculative logits.
        if self.speculator is not None:
            return self.speculator(input)

        assert self.head is not None
        logits = self.head(input)
        return logits, None
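
# Minimal usage sketch (hypothetical sizes, no speculator attached): build a
# SpeculativeHead around a plain nn.Linear and run a forward pass.
if __name__ == "__main__":
    hidden_size, vocab_size = 16, 32  # made-up dimensions for illustration
    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    head = SpeculativeHead(lm_head, speculator=None)

    hidden_states = torch.randn(4, hidden_size)
    logits, speculative_logits = head(hidden_states)
    # Without a speculator, forward returns the LM-head logits and None.
    assert logits.shape == (4, vocab_size)
    assert speculative_logits is None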