diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs index 3e68f6b..9f0416c 100644 --- a/launcher/tests/integration_tests.rs +++ b/launcher/tests/integration_tests.rs @@ -1,13 +1,13 @@ -use std::fs::File; +use float_eq::assert_float_eq; +use serde::Deserialize; use serde_json::Value; +use std::fs::File; use std::io::{BufRead, BufReader}; use std::path::PathBuf; use std::thread; use std::thread::sleep; use std::time::Duration; -use float_eq::assert_float_eq; use subprocess::{Popen, PopenConfig, Redirection}; -use serde::Deserialize; #[derive(Deserialize)] struct Details { @@ -22,7 +22,6 @@ struct GeneratedText { details: Details, } - fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen { let argv = vec![ "text-generation-launcher".to_string(), @@ -46,7 +45,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port ..Default::default() }, ) - .expect("Could not start launcher"); + .expect("Could not start launcher"); // Redirect STDOUT and STDERR to the console let launcher_stdout = launcher.stdout.take().unwrap(); @@ -63,7 +62,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port } }); - for _ in 0..30 { + for _ in 0..60 { let health = reqwest::blocking::get(format!("http://localhost:{}/health", port)); if health.is_ok() { return launcher; @@ -76,7 +75,12 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port panic!("failed to launch {}", model_name) } -fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText { +fn test_model( + model_name: String, + num_shard: usize, + port: usize, + master_port: usize, +) -> GeneratedText { let mut launcher = start_launcher(model_name, num_shard, port, master_port); let data = r#" @@ -101,7 +105,6 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us results.pop().unwrap() } - fn read_json(name: &str) -> GeneratedText { let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); d.push("tests/"); @@ -117,9 +120,17 @@ fn read_json(name: &str) -> GeneratedText { fn compare_results(result: GeneratedText, expected: GeneratedText) { assert_eq!(result.generated_text, expected.generated_text); assert_eq!(result.details.finish_reason, expected.details.finish_reason); - assert_eq!(result.details.generated_tokens, expected.details.generated_tokens); + assert_eq!( + result.details.generated_tokens, + expected.details.generated_tokens + ); - for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) { + for (token, expected_token) in result + .details + .tokens + .into_iter() + .zip(expected.details.tokens.into_iter()) + { assert_eq!(token.0, expected_token.0); assert_eq!(token.1, expected_token.1); if let Some(logprob) = token.2 { diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py index ed6e896..643cb83 100644 --- a/server/tests/test_utils.py +++ b/server/tests/test_utils.py @@ -11,46 +11,33 @@ from text_generation.utils import ( def test_stop_sequence_criteria(): - criteria = StopSequenceCriteria([1, 2, 3]) + criteria = StopSequenceCriteria("/test;") - assert not criteria(1) - assert criteria.current_token_idx == 1 - assert not criteria(2) - assert criteria.current_token_idx == 2 - assert criteria(3) - assert criteria.current_token_idx == 3 - - -def test_stop_sequence_criteria_reset(): - criteria = StopSequenceCriteria([1, 2, 3]) - - assert not criteria(1) - assert criteria.current_token_idx == 1 - assert not criteria(2) - assert criteria.current_token_idx == 2 - assert not criteria(4) - assert criteria.current_token_idx == 0 - - -def test_stop_sequence_criteria_empty(): - with pytest.raises(ValueError): - StopSequenceCriteria([]) + assert not criteria("/") + assert not criteria("/test") + assert criteria("/test;") + assert not criteria("/test; ") def test_stopping_criteria(): - criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5) - assert criteria([1]) == (False, None) - assert criteria([1, 2]) == (False, None) - assert criteria([1, 2, 3]) == (True, "stop_sequence") + criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) + assert criteria(65827, "/test") == (False, None) + assert criteria(30, ";") == (True, "stop_sequence") + + +def test_stopping_criteria_eos(): + criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) + assert criteria(1, "") == (False, None) + assert criteria(0, "") == (True, "eos_token") def test_stopping_criteria_max(): - criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5) - assert criteria([1]) == (False, None) - assert criteria([1, 1]) == (False, None) - assert criteria([1, 1, 1]) == (False, None) - assert criteria([1, 1, 1, 1]) == (False, None) - assert criteria([1, 1, 1, 1, 1]) == (True, "length") + criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) + assert criteria(1, "") == (False, None) + assert criteria(1, "") == (False, None) + assert criteria(1, "") == (False, None) + assert criteria(1, "") == (False, None) + assert criteria(1, "") == (True, "length") def test_weight_hub_files(): diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 5d62137..aeecf12 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -345,7 +345,12 @@ class CausalLM(Model): all_logprobs = torch.cat([all_logprobs, next_token_logprob]) # Evaluate stopping criteria - stop, reason = stopping_criteria(all_input_ids) + stop, reason = stopping_criteria( + next_token.squeeze(), + self.tokenizer.decode( + next_token.squeeze(), clean_up_tokenization_spaces=False + ), + ) if stop: # Decode all tokens output_text = self.tokenizer.decode( diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index e51ce60..fc80c60 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -441,7 +441,12 @@ class Seq2SeqLM(Model): decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob]) # Evaluate stopping criteria - stop, reason = stopping_criteria(decoder_input_ids) + stop, reason = stopping_criteria( + next_token.squeeze(), + self.tokenizer.decode( + next_token.squeeze(), clean_up_tokenization_spaces=False + ), + ) if stop: # Slice with decoder_input_length to remove padding # Decode all tokens diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index 9dd9115..cc77973 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -1,5 +1,6 @@ import concurrent import os +import re import torch import torch.distributed @@ -74,43 +75,39 @@ class NextTokenChooser: class StopSequenceCriteria: - def __init__(self, tokens: List[int]): - if not tokens: - raise ValueError("tokens cannot be empty") + def __init__(self, stop_sequence: str): + self.regex = re.compile(f".*{stop_sequence}$") - self.tokens = tokens - self.current_token_idx = 0 - - def __call__(self, last_token: int) -> bool: - if last_token == self.tokens[self.current_token_idx]: - # Increase idx to go to next token - self.current_token_idx += 1 - else: - # Reset to first token of the stopping sequence - self.current_token_idx = 0 - - if self.current_token_idx == len(self.tokens): - # We matched the entire sequence without resetting + def __call__(self, output: str) -> bool: + if self.regex.findall(output): return True return False class StoppingCriteria: def __init__( - self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20 + self, + eos_token_id: int, + stop_sequence_criterias: List[StopSequenceCriteria], + max_new_tokens=20, ): + self.eos_token_id = eos_token_id self.stop_sequence_criterias = stop_sequence_criterias self.max_new_tokens = max_new_tokens self.current_tokens = 0 + self.current_output = "" - def __call__(self, all_ids) -> Tuple[bool, Optional[str]]: + def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: self.current_tokens += 1 if self.current_tokens >= self.max_new_tokens: return True, "length" - last_token = all_ids[-1] + if last_token == self.eos_token_id: + return True, "eos_token" + + self.current_output += last_output for stop_sequence_criteria in self.stop_sequence_criterias: - if stop_sequence_criteria(last_token): + if stop_sequence_criteria(self.current_output): return True, "stop_sequence" return False, None @@ -119,16 +116,12 @@ class StoppingCriteria: def from_pb( cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer ) -> "StoppingCriteria": - stop_sequence_criterias = [] - for stop_sequence in pb.stop_sequences: - tokens = tokenizer( - stop_sequence, padding=False, return_attention_mask=False - ).input_ids - if tokens: - stop_sequence_criterias.append(StopSequenceCriteria(tokens)) - stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id])) - - return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens) + stop_sequence_criterias = [ + StopSequenceCriteria(sequence) for sequence in pb.stop_sequences + ] + return StoppingCriteria( + tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens + ) def initialize_torch_distributed():