hotfix: fix flashllama
parent 03c9388bf7
commit 27ff1871b5
@@ -692,7 +692,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         logits, speculative_logits = self.lm_head(hidden_states)

         # Used in Granite
-        if not self.logits_scaled:
+        if self.logits_scaling is not None and not self.logits_scaled:
             logits /= self.logits_scaling
             if speculative_logits is not None:
                 speculative_logits /= self.logits_scaling
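Note: the added None check keeps the Granite logits scaling from running on models whose config does not define a scaling factor. A minimal standalone sketch of the guarded path, assuming logits_scaling is None for non-Granite configs and that logits_scaled marks whether the head already applied the scaling (the helper and its arguments are illustrative names, not text-generation-inference API):

import torch

def scale_logits(logits, speculative_logits, logits_scaling, logits_scaled):
    # Only divide when a scaling factor exists and it has not already been
    # applied; without the None check, `logits /= None` raises a TypeError
    # for models that never set logits_scaling.
    if logits_scaling is not None and not logits_scaled:
        logits /= logits_scaling
        if speculative_logits is not None:
            speculative_logits /= logits_scaling
    return logits, speculative_logits

logits = torch.randn(2, 32000)
scale_logits(logits, None, None, False)  # plain Llama-style config: division skipped
scale_logits(logits, None, 8.0, False)   # Granite-style config: logits scaled in place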