fix(server): Fix stop sequences (#11)

This commit is contained in:
OlivierDehaene 2022-12-16 16:03:39 +01:00 committed by GitHub
parent 3e2e6240b8
commit 611e21cb13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 76 additions and 75 deletions

View File

@ -1,13 +1,13 @@
use std::fs::File;
use float_eq::assert_float_eq;
use serde::Deserialize;
use serde_json::Value;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::thread;
use std::thread::sleep;
use std::time::Duration;
use float_eq::assert_float_eq;
use subprocess::{Popen, PopenConfig, Redirection};
use serde::Deserialize;
#[derive(Deserialize)]
struct Details {
@ -22,7 +22,6 @@ struct GeneratedText {
details: Details,
}
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
let argv = vec![
"text-generation-launcher".to_string(),
@ -63,7 +62,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
}
});
for _ in 0..30 {
for _ in 0..60 {
let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
if health.is_ok() {
return launcher;
@ -76,7 +75,12 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
panic!("failed to launch {}", model_name)
}
fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
fn test_model(
model_name: String,
num_shard: usize,
port: usize,
master_port: usize,
) -> GeneratedText {
let mut launcher = start_launcher(model_name, num_shard, port, master_port);
let data = r#"
@ -101,7 +105,6 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us
results.pop().unwrap()
}
fn read_json(name: &str) -> GeneratedText {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("tests/");
@ -117,9 +120,17 @@ fn read_json(name: &str) -> GeneratedText {
fn compare_results(result: GeneratedText, expected: GeneratedText) {
assert_eq!(result.generated_text, expected.generated_text);
assert_eq!(result.details.finish_reason, expected.details.finish_reason);
assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);
assert_eq!(
result.details.generated_tokens,
expected.details.generated_tokens
);
for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) {
for (token, expected_token) in result
.details
.tokens
.into_iter()
.zip(expected.details.tokens.into_iter())
{
assert_eq!(token.0, expected_token.0);
assert_eq!(token.1, expected_token.1);
if let Some(logprob) = token.2 {

View File

@ -11,46 +11,33 @@ from text_generation.utils import (
def test_stop_sequence_criteria():
criteria = StopSequenceCriteria([1, 2, 3])
criteria = StopSequenceCriteria("/test;")
assert not criteria(1)
assert criteria.current_token_idx == 1
assert not criteria(2)
assert criteria.current_token_idx == 2
assert criteria(3)
assert criteria.current_token_idx == 3
def test_stop_sequence_criteria_reset():
criteria = StopSequenceCriteria([1, 2, 3])
assert not criteria(1)
assert criteria.current_token_idx == 1
assert not criteria(2)
assert criteria.current_token_idx == 2
assert not criteria(4)
assert criteria.current_token_idx == 0
def test_stop_sequence_criteria_empty():
with pytest.raises(ValueError):
StopSequenceCriteria([])
assert not criteria("/")
assert not criteria("/test")
assert criteria("/test;")
assert not criteria("/test; ")
def test_stopping_criteria():
criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
assert criteria([1]) == (False, None)
assert criteria([1, 2]) == (False, None)
assert criteria([1, 2, 3]) == (True, "stop_sequence")
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(65827, "/test") == (False, None)
assert criteria(30, ";") == (True, "stop_sequence")
def test_stopping_criteria_eos():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(0, "") == (True, "eos_token")
def test_stopping_criteria_max():
criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
assert criteria([1]) == (False, None)
assert criteria([1, 1]) == (False, None)
assert criteria([1, 1, 1]) == (False, None)
assert criteria([1, 1, 1, 1]) == (False, None)
assert criteria([1, 1, 1, 1, 1]) == (True, "length")
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, "length")
def test_weight_hub_files():

View File

@ -345,7 +345,12 @@ class CausalLM(Model):
all_logprobs = torch.cat([all_logprobs, next_token_logprob])
# Evaluate stopping criteria
stop, reason = stopping_criteria(all_input_ids)
stop, reason = stopping_criteria(
next_token.squeeze(),
self.tokenizer.decode(
next_token.squeeze(), clean_up_tokenization_spaces=False
),
)
if stop:
# Decode all tokens
output_text = self.tokenizer.decode(

View File

@ -441,7 +441,12 @@ class Seq2SeqLM(Model):
decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
# Evaluate stopping criteria
stop, reason = stopping_criteria(decoder_input_ids)
stop, reason = stopping_criteria(
next_token.squeeze(),
self.tokenizer.decode(
next_token.squeeze(), clean_up_tokenization_spaces=False
),
)
if stop:
# Slice with decoder_input_length to remove padding
# Decode all tokens

View File

@ -1,5 +1,6 @@
import concurrent
import os
import re
import torch
import torch.distributed
@ -74,43 +75,39 @@ class NextTokenChooser:
class StopSequenceCriteria:
def __init__(self, tokens: List[int]):
if not tokens:
raise ValueError("tokens cannot be empty")
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
self.tokens = tokens
self.current_token_idx = 0
def __call__(self, last_token: int) -> bool:
if last_token == self.tokens[self.current_token_idx]:
# Increase idx to go to next token
self.current_token_idx += 1
else:
# Reset to first token of the stopping sequence
self.current_token_idx = 0
if self.current_token_idx == len(self.tokens):
# We matched the entire sequence without resetting
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, "length"
last_token = all_ids[-1]
if last_token == self.eos_token_id:
return True, "eos_token"
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(last_token):
if stop_sequence_criteria(self.current_output):
return True, "stop_sequence"
return False, None
@ -119,16 +116,12 @@ class StoppingCriteria:
def from_pb(
cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
) -> "StoppingCriteria":
stop_sequence_criterias = []
for stop_sequence in pb.stop_sequences:
tokens = tokenizer(
stop_sequence, padding=False, return_attention_mask=False
).input_ids
if tokens:
stop_sequence_criterias.append(StopSequenceCriteria(tokens))
stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)
def initialize_torch_distributed():