Hotfixing qwen2 and starcoder2 (which also get clamping). (#2167)

Author: Nicolas Patry, 2024-07-02 14:26:47 +02:00 (committed by GitHub)
parent 963b6c6f0f
commit 0759ec495e
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
2 changed files with 2 additions and 2 deletions


@@ -368,7 +368,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,


@@ -534,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
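
Both hunks make the same one-line fix: the free-function call torch.clamp(input_lengths, max=self.max_past_tensor) becomes the method call input_lengths.clamp(max=self.max_past_tensor). For a plain tensor the two forms are interchangeable, so the switch presumably matters because input_lengths is no longer a bare tensor at this point but a wrapper object that carries its own clamp method. Below is a minimal sketch of that assumption; the Seqlen name and its field are illustrative stand-ins, not the exact TGI API.

import torch


class Seqlen:
    # Toy stand-in for a sequence-length wrapper object (name assumed).
    def __init__(self, input_lengths: torch.Tensor):
        self.input_lengths = input_lengths

    def clamp(self, max):
        # Clamp the wrapped tensor, mirroring Tensor.clamp(max=...),
        # and return a new wrapper so callers keep the same type.
        return Seqlen(torch.clamp(self.input_lengths, max=max))


lengths = Seqlen(torch.tensor([3, 9, 17]))

# Old form: torch.clamp expects a Tensor, so it raises a TypeError on the wrapper:
#   torch.clamp(lengths, max=8)
# New form: dispatches to whatever object input_lengths actually is:
clamped = lengths.clamp(max=8)
print(clamped.input_lengths)  # tensor([3, 8, 8])

The method-call spelling works whether input_lengths is a raw tensor or a wrapper exposing clamp, which is what makes it the safer form here.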