Update for new API
This commit is contained in:
parent
0721649fc6
commit
b49cff3f07
|
@ -513,7 +513,7 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||||
def _advance(next_token_id, fsm_grammar_state, fsm):
|
def _advance(next_token_id, fsm_grammar_state, fsm):
|
||||||
if fsm_grammar_state == -1:
|
if fsm_grammar_state == -1:
|
||||||
return fsm_grammar_state
|
return fsm_grammar_state
|
||||||
return fsm.next_state(fsm_grammar_state, next_token_id)
|
return fsm.get_next_state(fsm_grammar_state, next_token_id)
|
||||||
|
|
||||||
# TODO: move grammar compilation into the router
|
# TODO: move grammar compilation into the router
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -588,7 +588,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||||
fsm = self.fsms[i]
|
fsm = self.fsms[i]
|
||||||
if fsm_grammar_states[i] == -1 or fsm is None:
|
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||||
continue
|
continue
|
||||||
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
|
||||||
mask[i, allowed_tokens] = 0
|
mask[i, allowed_tokens] = 0
|
||||||
logits[i] += mask[i]
|
logits[i] += mask[i]
|
||||||
return logits
|
return logits
|
||||||
|
|
Loading…
Reference in New Issue