fix: add adapter_data param to phi and neox

This commit is contained in:
drbh 2024-06-07 03:28:15 +00:00
parent b1169273fd
commit 1deb372564
2 changed files with 2 additions and 0 deletions

View File

@@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
     max_s: int,
     prefill_cache_indices: Optional[torch.Tensor],
     lm_head_indices: Optional[torch.Tensor] = None,
+    adapter_data: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     hidden_states = self.gpt_neox(
         input_ids,

View File

@@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
     max_s: int,
     prefill_cache_indices: Optional[torch.Tensor],
     lm_head_indices: Optional[torch.Tensor] = None,
+    adapter_data: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     hidden_states = self.model(
         input_ids,