Tied embeddings in MLP speculator.

This commit is contained in:
Nicolas Patry 2024-08-29 12:30:26 +02:00
parent 5e2932552c
commit 5838f2139f
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
2 changed files with 93 additions and 1 deletions

View File

@ -45,12 +45,95 @@ class MLPSpeculatorLayerNorm(nn.Module):
return x
class MLPSpeculatorModelTied(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.config = config
self.n_predict = get_speculate()
self.hidden_size = config.hidden_size
self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
self.proj0 = FastLinear.load(
config,
prefix=f"{prefix}.proj.0",
weights=weights,
bias=False,
)
self.proj1 = FastLinear.load(
config,
prefix=f"{prefix}.proj.1",
weights=weights,
bias=False,
)
self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
self.ln = MLPSpeculatorLayerNorm(
prefix=f"{prefix}.ln.0",
config=config,
weights=weights,
)
# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
self.emb_weight = math.sqrt(1 - self.state_weight**2)
self.activation = nn.GELU()
# TODO
self.vsize = config.vocab_size
self.inner_dim = config.speculator_config["inner_dim"]
self.top_k_tokens_per_head = [1] * self.n_predict
def forward(
self,
hidden_states: torch.Tensor,
input_ids: torch.Tensor,
):
top_k_tokens_per_head = self.top_k_tokens_per_head
# k indicates # of candidates
# h indicates # of generated tokens
state = hidden_states
b = state.size(0)
ind = input_ids.unsqueeze(0)
all_probs = torch.empty(
b, self.n_predict, self.vsize, device=state.device
) # b k h v
assert (
len(top_k_tokens_per_head) == self.n_predict
), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
for i in range(self.n_predict):
# Project and predict
z = self.emb(ind)
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
if i == 0:
state = self.proj0(state) * self.state_weight + z
else:
state = self.proj1(state) * self.state_weight + z
state = self.activation(self.ln(state)) # b k d
probs = F.log_softmax(self.head(state), dim=-1) # b k v
_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
# Update candidate set with new predictions
# Update distribution set with new logits
all_probs[:, i] = probs.exp()
# Update state, log_probs and ind for new predictions
state = state.unsqueeze(2).expand(
-1, -1, top_k_tokens_per_head[i], -1
) # b k k' d
state = state.reshape(-1, b, state.size(3)) # b kk' d
ind = preds.view(-1, b) # b kk'
speculative_logits = all_probs
return speculative_logits
class MLPSpeculatorModel(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.config = config
self.n_predict = get_speculate()
self.hidden_size = config.hidden_size
self.emb = nn.ModuleList(
[
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
@ -171,6 +254,10 @@ class MLPSpeculatorHead(nn.Module):
)
routing[k] = filename
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
tie_weights = config.speculator_config.get("tie_weights", False)
if tie_weights:
mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
else:
mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
lm_head = TensorParallelHead.load(config, prefix, weights)
return MLPSpeculatorHead(lm_head, mlp_speculator)

View File

@ -458,6 +458,11 @@ def get_model(
revision=mlp_revision,
filename=filename,
)
speculator_dir_path = Path(mlp_speculator_config).parent
# if these are downloaded, they get converted to safetensors
filenames.extend(
[p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
)
speculator = {
"path": Path(mlp_speculator_config).parent,
"model_paths": filenames,