2023-03-07 10:52:22 -07:00
|
|
|
from text_generation_server.utils.tokens import (
|
2022-12-12 10:25:22 -07:00
|
|
|
StopSequenceCriteria,
|
|
|
|
StoppingCriteria,
|
2023-02-03 04:43:37 -07:00
|
|
|
FinishReason,
|
2022-12-08 10:49:33 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
|
2022-12-12 10:25:22 -07:00
|
|
|
def test_stop_sequence_criteria():
    """Check that StopSequenceCriteria fires on the full stop sequence only.

    A bare prefix of the sequence, or the sequence followed by extra
    characters, must not trigger it.
    """
    stop_criteria = StopSequenceCriteria("/test;")

    # (generated text, whether the stop sequence should fire)
    cases = [
        ("/", False),
        ("/test", False),
        ("/test;", True),
        # Trailing characters after the stop sequence mean no match.
        ("/test; ", False),
    ]
    for text, should_stop in cases:
        assert bool(stop_criteria(text)) == should_stop
|
2022-12-12 10:25:22 -07:00
|
|
|
|
|
|
|
|
2023-04-05 11:37:41 -06:00
|
|
|
def test_stop_sequence_criteria_escape():
    """Check that '|' characters in a stop sequence are matched literally.

    If '|' were interpreted as regex alternation, a lone "<" would match;
    the assertions below require that it does not.
    """
    stop_criteria = StopSequenceCriteria("<|stop|>")

    # Prefixes of the stop sequence must not trigger it.
    for partial in ("<", "<|stop"):
        assert not stop_criteria(partial)

    # The exact literal sequence triggers it...
    assert stop_criteria("<|stop|>")
    # ...but not when it is followed by extra text.
    assert not stop_criteria("<|stop|> ")
|
|
|
|
|
|
|
|
|
2022-12-16 08:03:39 -07:00
|
|
|
def test_stopping_criteria():
    """Check that a stop sequence completed across two calls ends generation.

    The first call supplies only the prefix "/test"; the second supplies ";",
    and only then is FINISH_REASON_STOP_SEQUENCE reported.
    """
    eos_token_id = 0
    stopping = StoppingCriteria(
        eos_token_id, [StopSequenceCriteria("/test;")], max_new_tokens=5
    )

    # "/test" alone is only a prefix of the stop sequence: keep going.
    assert stopping(65827, "/test") == (False, None)
    # Appending ";" completes "/test;" and stops on the stop sequence.
    assert stopping(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
|
2022-12-12 10:25:22 -07:00
|
|
|
|
|
|
|
|
2022-12-16 08:03:39 -07:00
|
|
|
def test_stopping_criteria_eos():
    """Check that emitting the configured EOS token id ends generation.

    Any other token id keeps generation running; the EOS id yields
    FINISH_REASON_EOS_TOKEN.
    """
    eos_token_id = 0
    stopping = StoppingCriteria(
        eos_token_id, [StopSequenceCriteria("/test;")], max_new_tokens=5
    )

    non_eos_token_id = 1
    assert stopping(non_eos_token_id, "") == (False, None)
    assert stopping(eos_token_id, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
|
2022-12-12 10:25:22 -07:00
|
|
|
|
|
|
|
|
|
|
|
def test_stopping_criteria_max():
    """Check that generation stops with a length finish after max_new_tokens calls.

    The first max_new_tokens - 1 calls continue; the call that reaches the
    budget reports FINISH_REASON_LENGTH.
    """
    max_new_tokens = 5
    stopping = StoppingCriteria(
        0, [StopSequenceCriteria("/test;")], max_new_tokens=max_new_tokens
    )

    # Every token before the budget is exhausted keeps generation going.
    for _ in range(max_new_tokens - 1):
        assert stopping(1, "") == (False, None)

    # The token that reaches the budget ends generation on length.
    assert stopping(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
|