fix(server): Fix stop sequences (#11)

OlivierDehaene 2022-12-16 16:03:39 +01:00 committed by GitHub
parent 3e2e6240b8
commit 611e21cb13
5 changed files with 76 additions and 75 deletions

View File

@@ -1,13 +1,13 @@
-use std::fs::File;
+use float_eq::assert_float_eq;
+use serde::Deserialize;
 use serde_json::Value;
+use std::fs::File;
 use std::io::{BufRead, BufReader};
 use std::path::PathBuf;
 use std::thread;
 use std::thread::sleep;
 use std::time::Duration;
-use float_eq::assert_float_eq;
 use subprocess::{Popen, PopenConfig, Redirection};
-use serde::Deserialize;
 
 #[derive(Deserialize)]
 struct Details {
@@ -22,7 +22,6 @@ struct GeneratedText {
     details: Details,
 }
 
-
 fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     let argv = vec![
         "text-generation-launcher".to_string(),
@@ -46,7 +45,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
             ..Default::default()
         },
     )
-        .expect("Could not start launcher");
+    .expect("Could not start launcher");
 
     // Redirect STDOUT and STDERR to the console
     let launcher_stdout = launcher.stdout.take().unwrap();
@@ -63,7 +62,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
         }
     });
 
-    for _ in 0..30 {
+    for _ in 0..60 {
        let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
        if health.is_ok() {
            return launcher;
@@ -76,7 +75,12 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
     panic!("failed to launch {}", model_name)
 }
 
-fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
+fn test_model(
+    model_name: String,
+    num_shard: usize,
+    port: usize,
+    master_port: usize,
+) -> GeneratedText {
     let mut launcher = start_launcher(model_name, num_shard, port, master_port);
 
     let data = r#"
@@ -101,7 +105,6 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us
     results.pop().unwrap()
 }
 
-
 fn read_json(name: &str) -> GeneratedText {
     let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
     d.push("tests/");
@@ -117,9 +120,17 @@ fn read_json(name: &str) -> GeneratedText {
 fn compare_results(result: GeneratedText, expected: GeneratedText) {
     assert_eq!(result.generated_text, expected.generated_text);
     assert_eq!(result.details.finish_reason, expected.details.finish_reason);
-    assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);
+    assert_eq!(
+        result.details.generated_tokens,
+        expected.details.generated_tokens
+    );
 
-    for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) {
+    for (token, expected_token) in result
+        .details
+        .tokens
+        .into_iter()
+        .zip(expected.details.tokens.into_iter())
+    {
         assert_eq!(token.0, expected_token.0);
         assert_eq!(token.1, expected_token.1);
         if let Some(logprob) = token.2 {

View File

@@ -11,46 +11,33 @@ from text_generation.utils import (
 
 
 def test_stop_sequence_criteria():
-    criteria = StopSequenceCriteria([1, 2, 3])
-
-    assert not criteria(1)
-    assert criteria.current_token_idx == 1
-    assert not criteria(2)
-    assert criteria.current_token_idx == 2
-    assert criteria(3)
-    assert criteria.current_token_idx == 3
-
-
-def test_stop_sequence_criteria_reset():
-    criteria = StopSequenceCriteria([1, 2, 3])
-
-    assert not criteria(1)
-    assert criteria.current_token_idx == 1
-    assert not criteria(2)
-    assert criteria.current_token_idx == 2
-    assert not criteria(4)
-    assert criteria.current_token_idx == 0
-
-
-def test_stop_sequence_criteria_empty():
-    with pytest.raises(ValueError):
-        StopSequenceCriteria([])
+    criteria = StopSequenceCriteria("/test;")
+
+    assert not criteria("/")
+    assert not criteria("/test")
+    assert criteria("/test;")
+    assert not criteria("/test; ")
 
 
 def test_stopping_criteria():
-    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
-    assert criteria([1]) == (False, None)
-    assert criteria([1, 2]) == (False, None)
-    assert criteria([1, 2, 3]) == (True, "stop_sequence")
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(65827, "/test") == (False, None)
+    assert criteria(30, ";") == (True, "stop_sequence")
+
+
+def test_stopping_criteria_eos():
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(1, "") == (False, None)
+    assert criteria(0, "") == (True, "eos_token")
 
 
 def test_stopping_criteria_max():
-    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
-    assert criteria([1]) == (False, None)
-    assert criteria([1, 1]) == (False, None)
-    assert criteria([1, 1, 1]) == (False, None)
-    assert criteria([1, 1, 1, 1]) == (False, None)
-    assert criteria([1, 1, 1, 1, 1]) == (True, "length")
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (True, "length")
 
 
 def test_weight_hub_files():
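
Note: the rewritten tests above pin down the end-anchored matching of the new string-based criteria. A minimal re-statement of the rule (the pattern built in StopSequenceCriteria, shown in the last file of this diff, is f".*{stop_sequence}$"):

    import re

    # Pattern as compiled for the stop sequence "/test;". Note the sequence is
    # interpolated into the regex unescaped, so regex metacharacters in a stop
    # sequence would be read as regex syntax ("/" and ";" are safe here).
    regex = re.compile(".*/test;$")

    assert regex.findall("/test;")       # fires: output ends with the sequence
    assert not regex.findall("/test")    # no match: sequence still incomplete
    assert not regex.findall("/test; ")  # no match: trailing space after it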

View File

@@ -345,7 +345,12 @@ class CausalLM(Model):
                 all_logprobs = torch.cat([all_logprobs, next_token_logprob])
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(all_input_ids)
+            stop, reason = stopping_criteria(
+                next_token.squeeze(),
+                self.tokenizer.decode(
+                    next_token.squeeze(), clean_up_tokenization_spaces=False
+                ),
+            )
             if stop:
                 # Decode all tokens
                 output_text = self.tokenizer.decode(
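
Note: the criteria are now fed one token at a time, as a (token id, decoded text) pair, instead of the whole all_input_ids tensor. A sketch of the call pattern, assuming tokenizer, next_token, and stopping_criteria as bound in the surrounding method (the Seq2SeqLM change in the next file is identical in shape):

    # Decode only the newly generated token; StoppingCriteria accumulates the
    # decoded text across calls, so the full sequence never needs re-decoding
    # and stop sequences spanning several tokens are still detected.
    last_text = tokenizer.decode(
        next_token.squeeze(), clean_up_tokenization_spaces=False
    )
    stop, reason = stopping_criteria(next_token.squeeze(), last_text)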

View File

@@ -441,7 +441,12 @@ class Seq2SeqLM(Model):
                 decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(decoder_input_ids)
+            stop, reason = stopping_criteria(
+                next_token.squeeze(),
+                self.tokenizer.decode(
+                    next_token.squeeze(), clean_up_tokenization_spaces=False
+                ),
+            )
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens

View File

@@ -1,5 +1,6 @@
 import concurrent
 import os
+import re
 
 import torch
 import torch.distributed
@@ -74,43 +75,39 @@ class NextTokenChooser:
 
 
 class StopSequenceCriteria:
-    def __init__(self, tokens: List[int]):
-        if not tokens:
-            raise ValueError("tokens cannot be empty")
-
-        self.tokens = tokens
-        self.current_token_idx = 0
-
-    def __call__(self, last_token: int) -> bool:
-        if last_token == self.tokens[self.current_token_idx]:
-            # Increase idx to go to next token
-            self.current_token_idx += 1
-        else:
-            # Reset to first token of the stopping sequence
-            self.current_token_idx = 0
-
-        if self.current_token_idx == len(self.tokens):
-            # We matched the entire sequence without resetting
+    def __init__(self, stop_sequence: str):
+        self.regex = re.compile(f".*{stop_sequence}$")
+
+    def __call__(self, output: str) -> bool:
+        if self.regex.findall(output):
             return True
         return False
 
 
 class StoppingCriteria:
     def __init__(
-        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
     ):
+        self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
+        self.current_output = ""
 
-    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
+    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
             return True, "length"
 
-        last_token = all_ids[-1]
+        if last_token == self.eos_token_id:
+            return True, "eos_token"
+
+        self.current_output += last_output
         for stop_sequence_criteria in self.stop_sequence_criterias:
-            if stop_sequence_criteria(last_token):
+            if stop_sequence_criteria(self.current_output):
                 return True, "stop_sequence"
 
         return False, None
@@ -119,16 +116,12 @@ class StoppingCriteria:
     def from_pb(
         cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
    ) -> "StoppingCriteria":
-        stop_sequence_criterias = []
-        for stop_sequence in pb.stop_sequences:
-            tokens = tokenizer(
-                stop_sequence, padding=False, return_attention_mask=False
-            ).input_ids
-            if tokens:
-                stop_sequence_criterias.append(StopSequenceCriteria(tokens))
-        stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
-
-        return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)
+        stop_sequence_criterias = [
+            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+        ]
+        return StoppingCriteria(
+            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
+        )
 
 
 def initialize_torch_distributed():
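
Note: the key behavioural fix is that matching now runs on accumulated decoded text rather than on a fixed token-id sequence, so a stop sequence is caught even when the model reaches it through a tokenization different from tokenizer(stop_sequence).input_ids. A small sketch against the new classes (token ids 11, 22, 33 are arbitrary illustrations; eos_token_id is 0 as in the tests):

    criteria = StoppingCriteria(
        0, [StopSequenceCriteria("/test;")], max_new_tokens=10
    )

    # "/test;" arrives split across three tokens; the criteria object
    # concatenates the decoded pieces and the regex fires on the last one.
    assert criteria(11, "/te") == (False, None)
    assert criteria(22, "st") == (False, None)
    assert criteria(33, ";") == (True, "stop_sequence")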