feat: medusa v2 (#1734)
This commit is contained in:
parent
1b2670c823
commit
eefea5ee31
|
@ -145,7 +145,7 @@ def get_model(
|
|||
if speculate is not None:
|
||||
if speculate > speculate_medusa:
|
||||
raise RuntimeError(
|
||||
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
|
||||
f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
|
||||
)
|
||||
else:
|
||||
set_speculate(speculate)
|
||||
|
|
|
@ -814,7 +814,7 @@ class FlashCausalLM(Model):
|
|||
for bs in CUDA_GRAPHS:
|
||||
if self.speculate is None or self.speculate + 1 <= bs:
|
||||
self.cuda_graph_warmup(bs, max_s, max_bt)
|
||||
except Exception:
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
logger.exception(f"Decode cuda graph warmup failed")
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
@ -874,22 +874,14 @@ class FlashCausalLM(Model):
|
|||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
if bs == 3:
|
||||
padded_bs = 4
|
||||
elif 3 < bs <= 8:
|
||||
padded_bs = 8
|
||||
elif bs > 8:
|
||||
padded_bs = (bs + 7) // 8 * 8
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
if sorted_padded_bs:
|
||||
# Get associated cuda graph
|
||||
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
|
||||
else:
|
||||
cuda_graph = None
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if (
|
||||
cu_seqlen_prefill is not None
|
||||
or cuda_graph is None
|
||||
or batch.speculative_ids is not None
|
||||
):
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
|
|
@ -432,12 +432,12 @@ class ResBlock(torch.nn.Module):
|
|||
|
||||
|
||||
class MedusaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, config, medusa_config, weights):
|
||||
super().__init__()
|
||||
self.heads = torch.nn.ModuleList(
|
||||
[
|
||||
MedusaHead(config, prefix=f"{i}", weights=weights)
|
||||
for i in range(config["medusa_num_heads"])
|
||||
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
|
||||
for i in range(medusa_config["medusa_num_heads"])
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -447,12 +447,12 @@ class MedusaModel(torch.nn.Module):
|
|||
|
||||
|
||||
class MedusaHead(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
def __init__(self, config, medusa_config, prefix, weights):
|
||||
super().__init__()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
||||
for i in range(config["medusa_num_layers"])
|
||||
for i in range(medusa_config["medusa_num_layers"])
|
||||
]
|
||||
)
|
||||
n = len(self.blocks)
|
||||
|
@ -467,7 +467,7 @@ class MedusaHead(torch.nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class SpeculativeHead(nn.Module):
|
||||
class MedusaHeadV1(nn.Module):
|
||||
def __init__(self, lm_head, medusa):
|
||||
super().__init__()
|
||||
self.lm_head = lm_head
|
||||
|
@ -475,38 +475,147 @@ class SpeculativeHead(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
|
||||
use_medusa = config.use_medusa
|
||||
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
medusa_config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing and routing[k] != filename:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
|
||||
medusa = MedusaModel(config, medusa_config, weights)
|
||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
return MedusaHeadV1(lm_head, medusa)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
logits = self.lm_head(input)
|
||||
speculative_logits = self.medusa(input)
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
class MedusaHeadV2(nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
|
||||
use_medusa = config.use_medusa
|
||||
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
medusa_config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing and routing[k] != filename:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
|
||||
self.n_medusa_heads = medusa_config["medusa_num_heads"]
|
||||
|
||||
assert medusa_config["medusa_num_layers"] == 1
|
||||
self.linear = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.process_group = weights.process_group
|
||||
self.world_size = self.process_group.size()
|
||||
self.rank = self.process_group.rank()
|
||||
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
self.lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
|
||||
def forward(self, x):
|
||||
size = x.shape[-1]
|
||||
block_size = (size + self.world_size - 1) // self.world_size
|
||||
start = self.rank * block_size
|
||||
stop = (self.rank + 1) * block_size
|
||||
|
||||
x_block = x[:, start:stop]
|
||||
|
||||
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
|
||||
medusa_res = self.act(self.linear(x)).reshape(
|
||||
*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
|
||||
)
|
||||
|
||||
# Apply all residual medusa heads
|
||||
output = x[:, start:stop].unsqueeze(-2) + medusa_res
|
||||
|
||||
# Gather medusa heads
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
|
||||
# Stack x and medusa residual x
|
||||
stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
|
||||
|
||||
# Compute lm head on x + medusa residual x
|
||||
logits = self.lm_head(stacked_x)
|
||||
|
||||
# Finally, split logits from speculative logits
|
||||
logits, speculative_logits = torch.split(
|
||||
logits, [1, self.n_medusa_heads], dim=-2
|
||||
)
|
||||
# Squeeze added dimension
|
||||
logits = logits.squeeze(-2)
|
||||
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
class SpeculativeHead(nn.Module):
|
||||
def __init__(self, lm_head, medusa):
|
||||
super().__init__()
|
||||
self.head = lm_head
|
||||
self.medusa = medusa
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
use_medusa = config.use_medusa
|
||||
if use_medusa:
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
weights.routing[k] = filename
|
||||
|
||||
medusa = MedusaModel(config, weights)
|
||||
lm_head = None
|
||||
try:
|
||||
medusa = MedusaHeadV1.load(config, prefix, weights)
|
||||
except:
|
||||
medusa = MedusaHeadV2(config, prefix, weights)
|
||||
else:
|
||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
medusa = None
|
||||
return SpeculativeHead(lm_head, medusa)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
logits = self.lm_head(input)
|
||||
speculative_logits = self.medusa(input) if self.medusa is not None else None
|
||||
return logits, speculative_logits
|
||||
if self.medusa is not None:
|
||||
return self.medusa(input)
|
||||
|
||||
assert self.head is not None
|
||||
logits = self.head(input)
|
||||
return logits, None
|
||||
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
|
|
Loading…
Reference in New Issue