Hotfixing qwen2 and starcoder2 (which also get clamping). (#2167)
parent 963b6c6f0f
commit 0759ec495e
@@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
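Below is a minimal illustrative sketch, not the actual TGI forward pass, of the behaviour these hunks touch: sequence lengths are clamped to the sliding-window size only in decode mode, because paged attention expects capped values while the flash attention prefill kernel needs the true lengths. The function and variable names here are hypothetical.

import torch

def maybe_clamp_input_lengths(
    input_lengths: torch.Tensor,    # per-sequence lengths, shape [batch]
    max_past_tensor: torch.Tensor,  # scalar tensor holding the sliding-window size
    is_prefill: bool,
) -> torch.Tensor:
    if is_prefill:
        # Flash attention needs the true lengths during prefill.
        return input_lengths
    # Paged attention in decode mode expects lengths clamped to the window,
    # mirroring input_lengths.clamp(max=self.max_past_tensor) in the diff above.
    return input_lengths.clamp(max=max_past_tensor)

lengths = torch.tensor([5000, 120, 9000])
window = torch.tensor(4096)
print(maybe_clamp_input_lengths(lengths, window, is_prefill=False))
# -> tensor([4096,  120, 4096])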