Tied embeddings in MLP speculator. (#2473)
* Tied embeddings in MLP speculator. * Fixing the scale_weight when users decide to not use the speculation as much as defined in the config. * Adding scaling support + optimize some ops.
This commit is contained in:
parent
9883f3b40e
commit
d9fbbaafb0
|
@ -45,12 +45,107 @@ class MLPSpeculatorLayerNorm(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
INV_SQRT2 = 2**-0.5
|
||||
|
||||
|
||||
def simple_norm(x: torch.Tensor, eps=1e-06):
|
||||
xf = x
|
||||
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
|
||||
x = xf.type_as(x)
|
||||
return x * INV_SQRT2
|
||||
|
||||
|
||||
class MLPSpeculatorModelTied(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.n_predict = get_speculate()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
|
||||
self.proj0 = FastLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.proj.0",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.proj1 = FastLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.proj.1",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
|
||||
self.ln = MLPSpeculatorLayerNorm(
|
||||
prefix=f"{prefix}.ln.0",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
|
||||
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
|
||||
self.activation = nn.GELU()
|
||||
self.vsize = config.vocab_size
|
||||
self.inner_dim = config.speculator_config["inner_dim"]
|
||||
self.top_k_tokens_per_head = [1] * self.n_predict
|
||||
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
|
||||
self.inner_dim / 2
|
||||
)
|
||||
self.emb.weight *= self.emb_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
):
|
||||
top_k_tokens_per_head = self.top_k_tokens_per_head
|
||||
|
||||
# k indicates # of candidates
|
||||
# h indicates # of generated tokens
|
||||
state = hidden_states
|
||||
b = state.size(0)
|
||||
ind = input_ids.unsqueeze(0)
|
||||
all_probs = torch.empty(
|
||||
b, self.n_predict, self.vsize, device=state.device
|
||||
) # b k h v
|
||||
assert (
|
||||
len(top_k_tokens_per_head) == self.n_predict
|
||||
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
|
||||
for i in range(self.n_predict):
|
||||
# Project and predict
|
||||
z = self.emb(ind)
|
||||
# z = z.mul(self.emb_weight) # b k d
|
||||
if i == 0:
|
||||
state = self.proj0(state) * self.state_weight + z
|
||||
else:
|
||||
state = self.proj1(state) * self.state_weight + z
|
||||
state = self.activation(self.ln(state)) # b k d
|
||||
probs = F.log_softmax(self.head(state), dim=-1) # b k v
|
||||
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
|
||||
|
||||
# Update candidate set with new predictions
|
||||
|
||||
# Update distribution set with new logits
|
||||
all_probs[:, i] = probs.exp()
|
||||
|
||||
# Update state, log_probs and ind for new predictions
|
||||
state = state.unsqueeze(2).expand(
|
||||
-1, -1, top_k_tokens_per_head[i], -1
|
||||
) # b k k' d
|
||||
state = state.reshape(-1, b, state.size(3)) # b kk' d
|
||||
ind = preds.view(-1, b) # b kk'
|
||||
|
||||
speculative_logits = all_probs
|
||||
return speculative_logits
|
||||
|
||||
|
||||
class MLPSpeculatorModel(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.n_predict = get_speculate()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.emb = nn.ModuleList(
|
||||
[
|
||||
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
|
||||
|
@ -84,13 +179,15 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||
)
|
||||
|
||||
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
|
||||
self.state_weight = 0.5 ** (0.5 / self.n_predict)
|
||||
self.emb_weight = math.sqrt(1 - self.state_weight**2)
|
||||
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
|
||||
self.activation = nn.GELU()
|
||||
# TODO
|
||||
self.vsize = config.vocab_size
|
||||
self.inner_dim = config.speculator_config["inner_dim"]
|
||||
self.top_k_tokens_per_head = [1] * self.n_predict
|
||||
self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
|
||||
self.inner_dim / 2
|
||||
)
|
||||
self.emb.weight *= self.emb_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -113,7 +210,7 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||
for i in range(self.n_predict):
|
||||
# Project and predict
|
||||
z = self.emb[i](ind)
|
||||
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
|
||||
# z = z.mul(self.emb_weight) # b k d
|
||||
state = self.proj[i](state) * self.state_weight + z
|
||||
state = self.activation(self.ln[i](state)) # b k d
|
||||
probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
|
||||
|
@ -136,10 +233,11 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||
|
||||
|
||||
class MLPSpeculatorHead(nn.Module):
|
||||
def __init__(self, lm_head, mlp_speculator):
|
||||
def __init__(self, lm_head, mlp_speculator, scale_input: bool):
|
||||
super().__init__()
|
||||
self.lm_head = lm_head
|
||||
self.mlp_speculator = mlp_speculator
|
||||
self.scale_input = scale_input
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor
|
||||
|
@ -150,6 +248,8 @@ class MLPSpeculatorHead(nn.Module):
|
|||
return logits, None
|
||||
|
||||
input_ids = logits.argmax(dim=-1)
|
||||
if self.scale_input:
|
||||
input = simple_norm(input)
|
||||
speculative_logits = self.mlp_speculator(input, input_ids)
|
||||
return logits, speculative_logits
|
||||
|
||||
|
@ -171,6 +271,12 @@ class MLPSpeculatorHead(nn.Module):
|
|||
)
|
||||
routing[k] = filename
|
||||
|
||||
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
|
||||
tie_weights = config.speculator_config.get("tie_weights", False)
|
||||
if tie_weights:
|
||||
mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
|
||||
else:
|
||||
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
|
||||
# This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
|
||||
scale_input = config.speculator_config.get("scale_input", False)
|
||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
return MLPSpeculatorHead(lm_head, mlp_speculator)
|
||||
return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
|
||||
|
|
|
@ -458,6 +458,11 @@ def get_model(
|
|||
revision=mlp_revision,
|
||||
filename=filename,
|
||||
)
|
||||
speculator_dir_path = Path(mlp_speculator_config).parent
|
||||
# if these are downloaded, they get converted to safetensors
|
||||
filenames.extend(
|
||||
[p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
|
||||
)
|
||||
speculator = {
|
||||
"path": Path(mlp_speculator_config).parent,
|
||||
"model_paths": filenames,
|
||||
|
|
Loading…
Reference in New Issue