Hotfixing qwen2 and starcoder2 (which also get clamping). (#2167)
parent 963b6c6f0f
commit 0759ec495e
@@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
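The change is identical in both models: before the decode (paged attention) path, per-sequence lengths are capped at the sliding-window size, while the prefill (flash attention) path keeps the true lengths. Below is a minimal, self-contained sketch of what that clamp does to a batch of lengths; it is not the repository's actual forward code, and the window size and example lengths are made up for illustration.

import torch

# Hypothetical sliding-window size (stands in for self.max_past / self.max_past_tensor).
max_past = 4096
max_past_tensor = torch.tensor(max_past, dtype=torch.int32)

# Example per-sequence lengths in a decode batch (illustrative values only).
input_lengths = torch.tensor([1200, 5000, 8192], dtype=torch.int32)

# Tensor-method form used after the hotfix; for a plain tensor this is equivalent
# to torch.clamp(input_lengths, max=max_past_tensor).
clamped = input_lengths.clamp(max=max_past_tensor)
print(clamped)  # tensor([1200, 4096, 4096], dtype=torch.int32)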