feat(launcher): Add integration tests (#9)

This commit is contained in:
OlivierDehaene 2022-12-16 11:29:36 +01:00 committed by GitHub
parent 32a253063d
commit 3e2e6240b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 434 additions and 14 deletions

View File

@ -5,6 +5,8 @@ on:
paths:
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
jobs:
run_tests:
@ -15,16 +17,25 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: 3.9
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: 1.65.0
override: true
components: rustfmt, clippy
- name: Loading cache.
uses: actions/cache@v2
id: model_cache
with:
path: ~/.cache/huggingface/
key: models
- name: Install server dependencies
- name: Install
run: |
make install-server
- name: Run tests
make install
- name: Run server tests
run: |
pip install pytest
pytest -sv server/tests
- name: Run Rust tests
run: |
cargo test

30
Cargo.lock generated
View File

@ -543,6 +543,12 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "float_eq"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
[[package]]
name = "fnv"
version = "1.0.7"
@ -1505,9 +1511,9 @@ dependencies = [
[[package]]
name = "reqwest"
version = "0.11.12"
version = "0.11.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc"
checksum = "68cc60575865c7831548863cc02356512e3f1dc2f3f82cb837d7fc4cc8f3c97c"
dependencies = [
"base64",
"bytes",
@ -1587,18 +1593,18 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.147"
version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965"
checksum = "e326c9ec8042f1b5da33252c8a37e9ffbd2c9bef0155215b6e6c80c790e05f91"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.147"
version = "1.0.150"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852"
checksum = "42a3df25b0713732468deadad63ab9da1f1fd75a48a15024b50363f128db627e"
dependencies = [
"proc-macro2",
"quote",
@ -1607,9 +1613,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.87"
version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45"
checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db"
dependencies = [
"itoa",
"ryu",
@ -1724,9 +1730,9 @@ dependencies = [
[[package]]
name = "syn"
version = "1.0.103"
version = "1.0.105"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
checksum = "60b9b43d45702de4c839cb9b51d9f529c5dd26a4aff255b42b1ebc03e88ee908"
dependencies = [
"proc-macro2",
"quote",
@ -1804,6 +1810,10 @@ version = "0.1.0"
dependencies = [
"clap 4.0.22",
"ctrlc",
"float_eq",
"reqwest",
"serde",
"serde_json",
"subprocess",
"tracing",
"tracing-subscriber",

View File

@ -7,7 +7,13 @@ description = "Text Generation Launcher"
[dependencies]
clap = { version = "4.0.15", features = ["derive", "env"] }
ctrlc = "3.2.3"
ctrlc = { version = "3.2.3", features = ["termination"] }
subprocess = "0.2.9"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies]
float_eq = "1.0.1"
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
serde = "1.0.150"
serde_json = "1.0.89"

View File

@ -0,0 +1,121 @@
[
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"tokens": [
[
10264,
"Test",
null
],
[
8821,
" request",
-11.895094
],
[
17,
".",
-1.8267941
],
[
1587,
"get",
-2.4674964
],
[
11,
"(",
-1.9060438
],
[
5,
"\"",
-1.2279553
],
[
4899,
"action",
-4.170306
],
[
5,
"\"",
-0.3247902
],
[
12,
")",
-1.0773602
],
[
30,
";",
-0.27640444
],
[
837,
"\n ",
-1.6970599
],
[
1320,
" if",
-1.4495552
],
[
375,
" (",
-0.2360998
],
[
4899,
"action",
-1.1916926
],
[
3535,
" ==",
-0.8918663
],
[
5109,
" null",
-0.39334255
],
[
12,
")",
-0.4321134
],
[
731,
" {",
-0.17701954
],
[
1260,
"\n ",
-0.07027287
],
[
10519,
" throw",
-1.3915133
],
[
2084,
" new",
-0.042013377
],
[
150858,
" RuntimeException",
-1.7330077
]
]
},
"generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
}
]

View File

@ -0,0 +1,156 @@
use std::fs::File;
use serde_json::Value;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::thread;
use std::thread::sleep;
use std::time::Duration;
use float_eq::assert_float_eq;
use subprocess::{Popen, PopenConfig, Redirection};
use serde::Deserialize;
#[derive(Deserialize)]
struct Details {
finish_reason: String,
generated_tokens: u32,
tokens: Vec<(u32, String, Option<f32>)>,
}
#[derive(Deserialize)]
struct GeneratedText {
generated_text: String,
details: Details,
}
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
let argv = vec![
"text-generation-launcher".to_string(),
"--model-name".to_string(),
model_name.clone(),
"--num-shard".to_string(),
num_shard.to_string(),
"--port".to_string(),
port.to_string(),
"--master-port".to_string(),
master_port.to_string(),
"--shard-uds-path".to_string(),
format!("/tmp/test-{}-{}-{}", num_shard, port, master_port),
];
let mut launcher = Popen::create(
&argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
..Default::default()
},
)
.expect("Could not start launcher");
// Redirect STDOUT and STDERR to the console
let launcher_stdout = launcher.stdout.take().unwrap();
let launcher_stderr = launcher.stderr.take().unwrap();
thread::spawn(move || {
let stdout = BufReader::new(launcher_stdout);
let stderr = BufReader::new(launcher_stderr);
for line in stdout.lines() {
println!("{}", line.unwrap());
}
for line in stderr.lines() {
println!("{}", line.unwrap());
}
});
for _ in 0..30 {
let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
if health.is_ok() {
return launcher;
}
sleep(Duration::from_secs(2));
}
launcher.terminate().unwrap();
launcher.wait().unwrap();
panic!("failed to launch {}", model_name)
}
fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
let mut launcher = start_launcher(model_name, num_shard, port, master_port);
let data = r#"
{
"inputs": "Test request",
"parameters": {
"details": true
}
}"#;
let req: Value = serde_json::from_str(data).unwrap();
let client = reqwest::blocking::Client::new();
let res = client
.post(format!("http://localhost:{}/generate", port))
.json(&req)
.send();
launcher.terminate().unwrap();
launcher.wait().unwrap();
let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
results.pop().unwrap()
}
fn read_json(name: &str) -> GeneratedText {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("tests/");
d.push(name);
let file = File::open(d).unwrap();
let reader = BufReader::new(file);
let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap();
results.pop().unwrap()
}
fn compare_results(result: GeneratedText, expected: GeneratedText) {
assert_eq!(result.generated_text, expected.generated_text);
assert_eq!(result.details.finish_reason, expected.details.finish_reason);
assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);
for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) {
assert_eq!(token.0, expected_token.0);
assert_eq!(token.1, expected_token.1);
if let Some(logprob) = token.2 {
let expected_logprob = expected_token.2.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
} else {
assert_eq!(token.2, expected_token.2);
}
}
}
#[test]
fn test_bloom_560m() {
let expected = read_json("bloom_560m.json");
let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
compare_results(result, expected);
}
#[test]
fn test_bloom_560m_distributed() {
let expected = read_json("bloom_560m.json");
let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
compare_results(result, expected);
}
#[test]
fn test_mt0_base() {
let expected = read_json("mt0_base.json");
let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
compare_results(result, expected);
}

View File

@ -0,0 +1,116 @@
[
{
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"tokens": [
[
0,
"<pad>",
null
],
[
259,
"",
-1.3656927
],
[
215100,
"\"\"\"",
-2.6551573
],
[
46138,
"Test",
-1.8059857
],
[
287,
"the",
-1.2102449
],
[
259,
"",
-1.6057279
],
[
49076,
"contents",
-3.6060903
],
[
304,
"of",
-0.5270343
],
[
287,
"the",
-0.62522805
],
[
259,
"",
-1.4069618
],
[
49076,
"contents",
-2.621994
],
[
304,
"of",
-1.3172221
],
[
287,
"the",
-0.3501925
],
[
259,
"",
-0.7219573
],
[
49076,
"contents",
-1.0494149
],
[
260,
".",
-1.0803378
],
[
259,
"",
-0.32933083
],
[
215100,
"\"\"\"",
-0.11268901
],
[
2978,
"test",
-1.5846587
],
[
290,
"_",
-0.49796978
],
[
4125,
"test",
-2.0026445
]
]
},
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
}
]