hotfix: fix flashllama

This commit is contained in:
OlivierDehaene 2024-10-23 13:22:31 +02:00
parent 03c9388bf7
commit 27ff1871b5
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
1 changed files with 1 additions and 1 deletions

View File

@ -692,7 +692,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
logits, speculative_logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
# Used in Granite # Used in Granite
if not self.logits_scaled: if self.logits_scaling is not None and not self.logits_scaled:
logits /= self.logits_scaling logits /= self.logits_scaling
if speculative_logits is not None: if speculative_logits is not None:
speculative_logits /= self.logits_scaling speculative_logits /= self.logits_scaling