Tied embeddings in MLP speculator.
This commit is contained in:
parent 5e2932552c
commit 5838f2139f
@@ -45,12 +45,95 @@ class MLPSpeculatorLayerNorm(nn.Module):
         return x
 
 
+class MLPSpeculatorModelTied(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.config = config
+        self.n_predict = get_speculate()
+        self.hidden_size = config.hidden_size
+
+        self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
+        self.proj0 = FastLinear.load(
+            config,
+            prefix=f"{prefix}.proj.0",
+            weights=weights,
+            bias=False,
+        )
+        self.proj1 = FastLinear.load(
+            config,
+            prefix=f"{prefix}.proj.1",
+            weights=weights,
+            bias=False,
+        )
+        self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
+        self.ln = MLPSpeculatorLayerNorm(
+            prefix=f"{prefix}.ln.0",
+            config=config,
+            weights=weights,
+        )
+
+        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
+        self.emb_weight = math.sqrt(1 - self.state_weight**2)
+        self.activation = nn.GELU()
+        # TODO
+        self.vsize = config.vocab_size
+        self.inner_dim = config.speculator_config["inner_dim"]
+        self.top_k_tokens_per_head = [1] * self.n_predict
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_ids: torch.Tensor,
+    ):
+        top_k_tokens_per_head = self.top_k_tokens_per_head
+
+        # k indicates # of candidates
+        # h indicates # of generated tokens
+        state = hidden_states
+        b = state.size(0)
+        ind = input_ids.unsqueeze(0)
+        all_probs = torch.empty(
+            b, self.n_predict, self.vsize, device=state.device
+        )  # b k h v
+        assert (
+            len(top_k_tokens_per_head) == self.n_predict
+        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+        for i in range(self.n_predict):
+            # Project and predict
+            z = self.emb(ind)
+            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
+            if i == 0:
+                state = self.proj0(state) * self.state_weight + z
+            else:
+                state = self.proj1(state) * self.state_weight + z
+            state = self.activation(self.ln(state))  # b k d
+            probs = F.log_softmax(self.head(state), dim=-1)  # b k v
+            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
+
+            # Update candidate set with new predictions
+
+            # Update distribution set with new logits
+            all_probs[:, i] = probs.exp()
+
+            # Update state, log_probs and ind for new predictions
+            state = state.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' d
+            state = state.reshape(-1, b, state.size(3))  # b kk' d
+            ind = preds.view(-1, b)  # b kk'
+
+        speculative_logits = all_probs
+        return speculative_logits
+
+
 class MLPSpeculatorModel(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.config = config
         self.n_predict = get_speculate()
         self.hidden_size = config.hidden_size
 
         self.emb = nn.ModuleList(
             [
                 TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
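The tied variant above reuses one embedding (emb.0), one LayerNorm (ln.0) and one head (head.0) across all n_predict steps, loading only two projections: proj.0 for the first step and proj.1 shared by every later step. As a sanity check on the state/embedding weighting, here is a small standalone sketch; the n_predict = 4 is an assumed example value, the real one comes from get_speculate():

import math

n_predict = 4  # assumed for illustration; the server takes this from get_speculate()
state_weight = 0.5 ** (0.5 / n_predict)      # per-step decay applied to the carried state
emb_weight = math.sqrt(1 - state_weight**2)  # complementary weight on the token embedding

# After n_predict steps the initial state has been scaled by
# state_weight ** n_predict = 0.5 ** 0.5, so its squared-magnitude
# (variance) share is 0.5 -- the "50%" in the comment in the diff.
assert abs((state_weight**n_predict) ** 2 - 0.5) < 1e-9

# Each step mixes state and embedding with unit total weight in variance terms:
assert abs(state_weight**2 + emb_weight**2 - 1.0) < 1e-9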
@@ -171,6 +254,10 @@ class MLPSpeculatorHead(nn.Module):
             )
             routing[k] = filename
 
-        mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        tie_weights = config.speculator_config.get("tie_weights", False)
+        if tie_weights:
+            mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
+        else:
+            mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)
         return MLPSpeculatorHead(lm_head, mlp_speculator)
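Dispatch between the two loaders hinges on a single tie_weights flag in the checkpoint's speculator config, defaulting to the untied path. A minimal sketch of a config that would select the tied loader; apart from tie_weights and inner_dim (both read by this diff), the keys and values are illustrative assumptions:

speculator_config = {
    "inner_dim": 4096,    # assumed width; MLPSpeculatorModelTied reads this key
    "tie_weights": True,  # absent or False falls back to MLPSpeculatorModel
}

# Mirrors the dispatch in MLPSpeculatorHead.load above:
tie_weights = speculator_config.get("tie_weights", False)
print("tied" if tie_weights else "untied")  # -> tied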
@@ -458,6 +458,11 @@ def get_model(
                 revision=mlp_revision,
                 filename=filename,
             )
+        speculator_dir_path = Path(mlp_speculator_config).parent
+        # if these are downloaded, they get converted to safetensors
+        filenames.extend(
+            [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
+        )
         speculator = {
             "path": Path(mlp_speculator_config).parent,
             "model_paths": filenames,
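This get_model hunk widens the speculator file search: besides the explicitly downloaded weights, any file already sitting next to the speculator's config.json with the converted-safetensors extension is picked up. A standalone sketch of that step, with all paths and names assumed for illustration:

import os
from pathlib import Path

mlp_speculator_config = "/data/speculator/config.json"  # assumed location
extension = ".safetensors"
filenames = ["speculator.safetensors"]  # assumed explicit downloads

speculator_dir_path = Path(mlp_speculator_config).parent
# os.listdir yields bare file names, matching the "model_paths" entries
# stored alongside "path" in the speculator dict above.
filenames.extend(
    [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
)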