fix(server): fix OPT implementation (#2061)

This commit is contained in:
OlivierDehaene 2024-06-12 18:22:20 +02:00 committed by GitHub
parent 376a0b7ada
commit 521de6cacd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 9 additions and 8 deletions

View File

@@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
 return_dict=return_dict,
 )
-logits, speculative_logits = self.lm_head(outputs)
+logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
 loss = None

View File

@@ -85,5 +85,4 @@ class GPTNeoxSharded(CausalLM):
 use_cache=True,
 )
-logits = outputs.logits
-return logits, speculative_logits, outputs.past_key_values
+return outputs.logits, speculative_logits, outputs.past_key_values

View File

@@ -75,11 +75,11 @@ class OPTSharded(CausalLM):
 def forward(
 self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
 ):
-outputs = self.model.forward(
+outputs, speculative_logits = self.model.forward(
 input_ids=input_ids,
 attention_mask=attention_mask,
 past_key_values=past_key_values,
 use_cache=True,
 )
-return outputs.logits, outputs.past_key_values
+return outputs.logits, speculative_logits, outputs.past_key_values

View File

@@ -71,11 +71,13 @@ class RW(CausalLM):
 def forward(
 self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+):
 # Model Forward
-outputs = self.model.forward(
+outputs, speculative_logits = self.model.forward(
 input_ids=input_ids,
 attention_mask=attention_mask,
 past_key_values=past_key_values,
 use_cache=True,
 )
-return outputs.logits, outputs.past_key_values
+return outputs.logits, speculative_logits, outputs.past_key_values