hotfix: fix flashllama
parent 03c9388bf7
commit 27ff1871b5
@@ -692,7 +692,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         logits, speculative_logits = self.lm_head(hidden_states)
 
         # Used in Granite
-        if not self.logits_scaled:
+        if self.logits_scaling is not None and not self.logits_scaled:
            logits /= self.logits_scaling
            if speculative_logits is not None:
                speculative_logits /= self.logits_scaling
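The hotfix adds an `is not None` guard before the Granite logits scaling is applied. Below is a minimal, hypothetical sketch of that branch in isolation; the `ScaledLMHead` class and the `fake_lm_head` stub are illustrative stand-ins rather than code from this repository, and it assumes `logits_scaling` is None for checkpoints that define no scaling factor and a float for Granite-style models.

import torch

class ScaledLMHead:
    # Minimal sketch of the fixed scaling branch, not the real FlashLlamaForCausalLM.
    def __init__(self, lm_head, logits_scaling=None, logits_scaled=False):
        self.lm_head = lm_head                # callable returning (logits, speculative_logits)
        self.logits_scaling = logits_scaling  # assumed None unless the model defines a scaling factor
        self.logits_scaled = logits_scaled    # assumed True only when scaling was already applied upstream

    def __call__(self, hidden_states):
        logits, speculative_logits = self.lm_head(hidden_states)
        # The added `is not None` check is the fix: without it, `logits /= None`
        # raises a TypeError on models that have no scaling factor.
        if self.logits_scaling is not None and not self.logits_scaled:
            logits /= self.logits_scaling
            if speculative_logits is not None:
                speculative_logits /= self.logits_scaling
        return logits, speculative_logits

def fake_lm_head(h):
    # Hypothetical stand-in for self.lm_head.
    return h.clone(), None

hidden = torch.randn(2, 8)
ScaledLMHead(fake_lm_head)(hidden)                       # no scaling factor: branch is skipped
ScaledLMHead(fake_lm_head, logits_scaling=16.0)(hidden)  # Granite-style: logits divided in place

Under the pre-fix condition `if not self.logits_scaled:`, the first call would attempt `logits /= None` and fail, which is presumably what this hotfix addresses.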