feat: medusa v2 (#1734)

This commit is contained in:
OlivierDehaene 2024-04-12 16:24:45 +02:00 committed by GitHub
parent 1b2670c823
commit eefea5ee31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 146 additions and 45 deletions

View File

@ -145,7 +145,7 @@ def get_model(
if speculate is not None:
if speculate > speculate_medusa:
raise RuntimeError(
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
)
else:
set_speculate(speculate)

View File

@ -814,7 +814,7 @@ class FlashCausalLM(Model):
for bs in CUDA_GRAPHS:
if self.speculate is None or self.speculate + 1 <= bs:
self.cuda_graph_warmup(bs, max_s, max_bt)
except Exception:
except torch.cuda.OutOfMemoryError:
logger.exception(f"Decode cuda graph warmup failed")
return int(num_blocks * BLOCK_SIZE)
@ -874,22 +874,14 @@ class FlashCausalLM(Model):
lm_head_indices = batch.prefill_head_indices
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
if (
cu_seqlen_prefill is not None
or cuda_graph is None
or batch.speculative_ids is not None
):
if cu_seqlen_prefill is not None or cuda_graph is None:
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,

View File

@ -432,12 +432,12 @@ class ResBlock(torch.nn.Module):
class MedusaModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, config, medusa_config, weights):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
for i in range(medusa_config["medusa_num_heads"])
]
)
@ -447,12 +447,12 @@ class MedusaModel(torch.nn.Module):
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, medusa_config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
for i in range(medusa_config["medusa_num_layers"])
]
)
n = len(self.blocks)
@ -467,7 +467,7 @@ class MedusaHead(torch.nn.Module):
return x
class SpeculativeHead(nn.Module):
class MedusaHeadV1(nn.Module):
def __init__(self, lm_head, medusa):
super().__init__()
self.lm_head = lm_head
@ -475,38 +475,147 @@ class SpeculativeHead(nn.Module):
@staticmethod
def load(config, prefix: str, weights):
lm_head = TensorParallelHead.load(config, prefix, weights)
use_medusa = config.use_medusa
if use_medusa:
from pathlib import Path
from safetensors import safe_open
import json
use_medusa = config.use_medusa
medusa_config = str(Path(use_medusa) / "config.json")
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
weights.routing[k] = filename
routing[k] = filename
medusa = MedusaModel(config, weights)
medusa = MedusaModel(config, medusa_config, weights)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MedusaHeadV1(lm_head, medusa)
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
speculative_logits = self.medusa(input)
return logits, speculative_logits
class MedusaHeadV2(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
from pathlib import Path
from safetensors import safe_open
import json
use_medusa = config.use_medusa
medusa_config = str(Path(use_medusa) / "config.json")
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
with open(medusa_config, "r") as f:
medusa_config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing and routing[k] != filename:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
self.n_medusa_heads = medusa_config["medusa_num_heads"]
assert medusa_config["medusa_num_layers"] == 1
self.linear = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
dim=0,
weights=weights,
bias=True,
)
self.process_group = weights.process_group
self.world_size = self.process_group.size()
self.rank = self.process_group.rank()
self.act = torch.nn.SiLU()
self.lm_head = TensorParallelHead.load(config, prefix, weights)
def forward(self, x):
size = x.shape[-1]
block_size = (size + self.world_size - 1) // self.world_size
start = self.rank * block_size
stop = (self.rank + 1) * block_size
x_block = x[:, start:stop]
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
medusa_res = self.act(self.linear(x)).reshape(
*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
)
# Apply all residual medusa heads
output = x[:, start:stop].unsqueeze(-2) + medusa_res
# Gather medusa heads
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
# Stack x and medusa residual x
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
# Compute lm head on x + medusa residual x
logits = self.lm_head(stacked_x)
# Finally, split logits from speculative logits
logits, speculative_logits = torch.split(
logits, [1, self.n_medusa_heads], dim=-2
)
# Squeeze added dimension
logits = logits.squeeze(-2)
return logits, speculative_logits
class SpeculativeHead(nn.Module):
def __init__(self, lm_head, medusa):
super().__init__()
self.head = lm_head
self.medusa = medusa
@staticmethod
def load(config, prefix: str, weights):
use_medusa = config.use_medusa
if use_medusa:
lm_head = None
try:
medusa = MedusaHeadV1.load(config, prefix, weights)
except:
medusa = MedusaHeadV2(config, prefix, weights)
else:
lm_head = TensorParallelHead.load(config, prefix, weights)
medusa = None
return SpeculativeHead(lm_head, medusa)
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
speculative_logits = self.medusa(input) if self.medusa is not None else None
return logits, speculative_logits
if self.medusa is not None:
return self.medusa(input)
assert self.head is not None
logits = self.head(input)
return logits, None
class TensorParallelHead(SuperLayer):