fix(server): fix OPT implementation (#2061)
This commit is contained in:
parent
376a0b7ada
commit
521de6cacd
|
@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
|||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
logits, speculative_logits = self.lm_head(outputs)
|
||||
logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
|
||||
|
||||
loss = None
|
||||
|
||||
|
|
|
@ -85,5 +85,4 @@ class GPTNeoxSharded(CausalLM):
|
|||
use_cache=True,
|
||||
)
|
||||
|
||||
logits = outputs.logits
|
||||
return logits, speculative_logits, outputs.past_key_values
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
||||
|
|
|
@ -75,11 +75,11 @@ class OPTSharded(CausalLM):
|
|||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
outputs = self.model.forward(
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
return outputs.logits, outputs.past_key_values
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
||||
|
|
|
@ -71,11 +71,13 @@ class RW(CausalLM):
|
|||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
):
|
||||
# Model Forward
|
||||
outputs = self.model.forward(
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
||||
|
|
Loading…
Reference in New Issue