fix(server): Fix stop sequences (#11)

parent 3e2e6240b8
commit 611e21cb13
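In short (summarizing the hunks below): stop sequences were previously matched as lists of token IDs, one token at a time, which is fragile because a tokenizer can split the same text across different token boundaries. `StopSequenceCriteria` now takes the stop sequence as a string and matches it with a regex against the decoded output, and `StoppingCriteria` accumulates the decoded text token by token, taking `eos_token_id` explicitly instead of modeling EOS as a one-token stop sequence. A minimal sketch of the new matching, condensed from the `server/text_generation/utils.py` hunk below (the `bool(...)` shorthand stands in for the explicit `return True`/`return False` in the diff):

```python
import re

class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        # Matches only when the accumulated output *ends with* the stop sequence.
        # Note: the sequence is interpolated unescaped, so regex metacharacters
        # in a user-supplied stop sequence are interpreted as regex syntax.
        self.regex = re.compile(f".*{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        return bool(self.regex.findall(output))

criteria = StopSequenceCriteria("/test;")
assert not criteria("/test")    # prefix only: keep generating
assert criteria("/test;")       # full match at the end: stop
assert not criteria("/test; ")  # trailing text after the match: no stop
```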
launcher/tests/integration_tests.rs

@@ -1,13 +1,13 @@
-use std::fs::File;
+use float_eq::assert_float_eq;
+use serde::Deserialize;
+use serde_json::Value;
+use std::fs::File;
 use std::io::{BufRead, BufReader};
 use std::path::PathBuf;
 use std::thread;
 use std::thread::sleep;
 use std::time::Duration;
-use float_eq::assert_float_eq;
 use subprocess::{Popen, PopenConfig, Redirection};
-use serde::Deserialize;
-use serde_json::Value;
 
 #[derive(Deserialize)]
 struct Details {
@@ -22,7 +22,6 @@ struct GeneratedText {
     details: Details,
 }
 
-
 fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     let argv = vec![
         "text-generation-launcher".to_string(),
@@ -63,7 +62,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
         }
     });
 
-    for _ in 0..30 {
+    for _ in 0..60 {
         let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
         if health.is_ok() {
             return launcher;
@@ -76,7 +75,12 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
     panic!("failed to launch {}", model_name)
 }
 
-fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
+fn test_model(
+    model_name: String,
+    num_shard: usize,
+    port: usize,
+    master_port: usize,
+) -> GeneratedText {
     let mut launcher = start_launcher(model_name, num_shard, port, master_port);
 
     let data = r#"
@@ -101,7 +105,6 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us
     results.pop().unwrap()
 }
 
-
 fn read_json(name: &str) -> GeneratedText {
     let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
     d.push("tests/");
@@ -117,9 +120,17 @@ fn read_json(name: &str) -> GeneratedText {
 fn compare_results(result: GeneratedText, expected: GeneratedText) {
     assert_eq!(result.generated_text, expected.generated_text);
     assert_eq!(result.details.finish_reason, expected.details.finish_reason);
-    assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);
+    assert_eq!(
+        result.details.generated_tokens,
+        expected.details.generated_tokens
+    );
 
-    for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) {
+    for (token, expected_token) in result
+        .details
+        .tokens
+        .into_iter()
+        .zip(expected.details.tokens.into_iter())
+    {
         assert_eq!(token.0, expected_token.0);
         assert_eq!(token.1, expected_token.1);
         if let Some(logprob) = token.2 {
server/tests/test_utils.py

@@ -11,46 +11,33 @@ from text_generation.utils import (
 
 
 def test_stop_sequence_criteria():
-    criteria = StopSequenceCriteria([1, 2, 3])
+    criteria = StopSequenceCriteria("/test;")
 
-    assert not criteria(1)
-    assert criteria.current_token_idx == 1
-    assert not criteria(2)
-    assert criteria.current_token_idx == 2
-    assert criteria(3)
-    assert criteria.current_token_idx == 3
-
-
-def test_stop_sequence_criteria_reset():
-    criteria = StopSequenceCriteria([1, 2, 3])
-
-    assert not criteria(1)
-    assert criteria.current_token_idx == 1
-    assert not criteria(2)
-    assert criteria.current_token_idx == 2
-    assert not criteria(4)
-    assert criteria.current_token_idx == 0
-
-
-def test_stop_sequence_criteria_empty():
-    with pytest.raises(ValueError):
-        StopSequenceCriteria([])
+    assert not criteria("/")
+    assert not criteria("/test")
+    assert criteria("/test;")
+    assert not criteria("/test; ")
 
 
 def test_stopping_criteria():
-    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
-    assert criteria([1]) == (False, None)
-    assert criteria([1, 2]) == (False, None)
-    assert criteria([1, 2, 3]) == (True, "stop_sequence")
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(65827, "/test") == (False, None)
+    assert criteria(30, ";") == (True, "stop_sequence")
 
 
+def test_stopping_criteria_eos():
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(1, "") == (False, None)
+    assert criteria(0, "") == (True, "eos_token")
+
+
 def test_stopping_criteria_max():
-    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
-    assert criteria([1]) == (False, None)
-    assert criteria([1, 1]) == (False, None)
-    assert criteria([1, 1, 1]) == (False, None)
-    assert criteria([1, 1, 1, 1]) == (False, None)
-    assert criteria([1, 1, 1, 1, 1]) == (True, "length")
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (True, "length")
 
 
 def test_weight_hub_files():
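The updated tests above exercise the new calling convention: a `StoppingCriteria` is now fed one `(token_id, decoded_text)` pair per decoding step rather than the full list of input IDs. A hypothetical driver loop using the same values as `test_stopping_criteria` (the token IDs 65827 and 30 are arbitrary non-EOS IDs taken from the test):

```python
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)

for token_id, text in [(65827, "/test"), (30, ";")]:
    stop, reason = criteria(token_id, text)
    if stop:
        print(reason)  # prints "stop_sequence" on the second step
```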
server/text_generation/models/causal_lm.py

@@ -345,7 +345,12 @@ class CausalLM(Model):
             all_logprobs = torch.cat([all_logprobs, next_token_logprob])
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(all_input_ids)
+            stop, reason = stopping_criteria(
+                next_token.squeeze(),
+                self.tokenizer.decode(
+                    next_token.squeeze(), clean_up_tokenization_spaces=False
+                ),
+            )
             if stop:
                 # Decode all tokens
                 output_text = self.tokenizer.decode(
server/text_generation/models/seq2seq_lm.py

@@ -441,7 +441,12 @@ class Seq2SeqLM(Model):
             decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(decoder_input_ids)
+            stop, reason = stopping_criteria(
+                next_token.squeeze(),
+                self.tokenizer.decode(
+                    next_token.squeeze(), clean_up_tokenization_spaces=False
+                ),
+            )
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens
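Both call sites above change the same way: instead of handing the stopping criteria the entire running sequence of IDs, each step now passes only the newly sampled token and its decoded text, and the criteria object accumulates the output string itself. Decoding with `clean_up_tokenization_spaces=False` keeps the per-token strings concatenable, so the accumulated text stays faithful to what a full decode would produce. A schematic of the per-step call (names as in the hunks above, shapes simplified, not the repo's literal generation loop):

```python
# One decoding step, schematically: decode only the new token and let
# StoppingCriteria keep the running output string.
last_token = next_token.squeeze()  # 1-element tensor -> scalar id
stop, reason = stopping_criteria(
    last_token,
    self.tokenizer.decode(last_token, clean_up_tokenization_spaces=False),
)
```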
server/text_generation/utils.py

@@ -1,5 +1,6 @@
 import concurrent
 import os
+import re
 import torch
 import torch.distributed
 
@@ -74,43 +75,39 @@ class NextTokenChooser:
 
 
 class StopSequenceCriteria:
-    def __init__(self, tokens: List[int]):
-        if not tokens:
-            raise ValueError("tokens cannot be empty")
-
-        self.tokens = tokens
-        self.current_token_idx = 0
-
-    def __call__(self, last_token: int) -> bool:
-        if last_token == self.tokens[self.current_token_idx]:
-            # Increase idx to go to next token
-            self.current_token_idx += 1
-        else:
-            # Reset to first token of the stopping sequence
-            self.current_token_idx = 0
+    def __init__(self, stop_sequence: str):
+        self.regex = re.compile(f".*{stop_sequence}$")
 
-        if self.current_token_idx == len(self.tokens):
-            # We matched the entire sequence without resetting
+    def __call__(self, output: str) -> bool:
+        if self.regex.findall(output):
             return True
         return False
 
 
 class StoppingCriteria:
     def __init__(
-        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
     ):
+        self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
+        self.current_output = ""
 
-    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
+    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
             return True, "length"
 
-        last_token = all_ids[-1]
+        if last_token == self.eos_token_id:
+            return True, "eos_token"
+
+        self.current_output += last_output
         for stop_sequence_criteria in self.stop_sequence_criterias:
-            if stop_sequence_criteria(last_token):
+            if stop_sequence_criteria(self.current_output):
                 return True, "stop_sequence"
 
         return False, None
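Because `current_output` accumulates the decoded text across calls and the regex is anchored at the end of the string, a stop sequence split across several generated tokens is still caught once its last piece arrives. A small demonstration, assuming the classes above (the token IDs are arbitrary non-EOS placeholders):

```python
criteria = StoppingCriteria(
    eos_token_id=0,
    stop_sequence_criterias=[StopSequenceCriteria("/test;")],
    max_new_tokens=20,
)
print(criteria(65827, "/te"))  # (False, None)
print(criteria(1234, "st"))    # (False, None)
print(criteria(30, ";"))       # (True, 'stop_sequence')
```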
@@ -119,16 +116,12 @@ class StoppingCriteria:
     def from_pb(
         cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
     ) -> "StoppingCriteria":
-        stop_sequence_criterias = []
-        for stop_sequence in pb.stop_sequences:
-            tokens = tokenizer(
-                stop_sequence, padding=False, return_attention_mask=False
-            ).input_ids
-            if tokens:
-                stop_sequence_criterias.append(StopSequenceCriteria(tokens))
-        stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
-
-        return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)
+        stop_sequence_criterias = [
+            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+        ]
+        return StoppingCriteria(
+            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
+        )
 
 
 def initialize_torch_distributed():