From 3e2e6240b8f557e4db104fbb4e05dc43456b0b48 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 16 Dec 2022 11:29:36 +0100 Subject: [PATCH] feat(launcher): Add integration tests (#9) --- .../{server-tests.yaml => tests.yaml} | 17 +- Cargo.lock | 30 ++-- launcher/Cargo.toml | 8 +- launcher/tests/bloom_560m.json | 121 ++++++++++++++ launcher/tests/integration_tests.rs | 156 ++++++++++++++++++ launcher/tests/mt0_base.json | 116 +++++++++++++ 6 files changed, 434 insertions(+), 14 deletions(-) rename .github/workflows/{server-tests.yaml => tests.yaml} (60%) create mode 100644 launcher/tests/bloom_560m.json create mode 100644 launcher/tests/integration_tests.rs create mode 100644 launcher/tests/mt0_base.json diff --git a/.github/workflows/server-tests.yaml b/.github/workflows/tests.yaml similarity index 60% rename from .github/workflows/server-tests.yaml rename to .github/workflows/tests.yaml index 5bb4653a..7bdd3b73 100644 --- a/.github/workflows/server-tests.yaml +++ b/.github/workflows/tests.yaml @@ -5,6 +5,8 @@ on: paths: - "server/**" - "proto/**" + - "router/**" + - "launcher/**" jobs: run_tests: @@ -15,16 +17,25 @@ jobs: uses: actions/setup-python@v1 with: python-version: 3.9 + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: 1.65.0 + override: true + components: rustfmt, clippy - name: Loading cache. uses: actions/cache@v2 id: model_cache with: path: ~/.cache/huggingface/ key: models - - name: Install server dependencies + - name: Install run: | - make install-server - - name: Run tests + make install + - name: Run server tests run: | pip install pytest pytest -sv server/tests + - name: Run Rust tests + run: | + cargo test diff --git a/Cargo.lock b/Cargo.lock index 0fd5c4bf..752c4886 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -543,6 +543,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "float_eq" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853" + [[package]] name = "fnv" version = "1.0.7" @@ -1505,9 +1511,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.12" +version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "431949c384f4e2ae07605ccaa56d1d9d2ecdb5cadd4f9577ccfab29f2e5149fc" +checksum = "68cc60575865c7831548863cc02356512e3f1dc2f3f82cb837d7fc4cc8f3c97c" dependencies = [ "base64", "bytes", @@ -1587,18 +1593,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.147" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965" +checksum = "e326c9ec8042f1b5da33252c8a37e9ffbd2c9bef0155215b6e6c80c790e05f91" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.147" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852" +checksum = "42a3df25b0713732468deadad63ab9da1f1fd75a48a15024b50363f128db627e" dependencies = [ "proc-macro2", "quote", @@ -1607,9 +1613,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.87" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" +checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db" dependencies = [ "itoa", "ryu", @@ -1724,9 +1730,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.103" +version = "1.0.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d" +checksum = "60b9b43d45702de4c839cb9b51d9f529c5dd26a4aff255b42b1ebc03e88ee908" dependencies = [ "proc-macro2", "quote", @@ -1804,6 +1810,10 @@ version = "0.1.0" dependencies = [ "clap 4.0.22", "ctrlc", + "float_eq", + "reqwest", + "serde", + "serde_json", "subprocess", "tracing", "tracing-subscriber", diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 1779c051..ecdef831 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -7,7 +7,13 @@ description = "Text Generation Launcher" [dependencies] clap = { version = "4.0.15", features = ["derive", "env"] } -ctrlc = "3.2.3" +ctrlc = { version = "3.2.3", features = ["termination"] } subprocess = "0.2.9" tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["json"] } + +[dev-dependencies] +float_eq = "1.0.1" +reqwest = { version = "0.11.13", features = ["blocking", "json"] } +serde = "1.0.150" +serde_json = "1.0.89" diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json new file mode 100644 index 00000000..d17f1ed4 --- /dev/null +++ b/launcher/tests/bloom_560m.json @@ -0,0 +1,121 @@ +[ + { + "details": { + "finish_reason": "length", + "generated_tokens": 20, + "tokens": [ + [ + 10264, + "Test", + null + ], + [ + 8821, + " request", + -11.895094 + ], + [ + 17, + ".", + -1.8267941 + ], + [ + 1587, + "get", + -2.4674964 + ], + [ + 11, + "(", + -1.9060438 + ], + [ + 5, + "\"", + -1.2279553 + ], + [ + 4899, + "action", + -4.170306 + ], + [ + 5, + "\"", + -0.3247902 + ], + [ + 12, + ")", + -1.0773602 + ], + [ + 30, + ";", + -0.27640444 + ], + [ + 837, + "\n ", + -1.6970599 + ], + [ + 1320, + " if", + -1.4495552 + ], + [ + 375, + " (", + -0.2360998 + ], + [ + 4899, + "action", + -1.1916926 + ], + [ + 3535, + " ==", + -0.8918663 + ], + [ + 5109, + " null", + -0.39334255 + ], + [ + 12, + ")", + -0.4321134 + ], + [ + 731, + " {", + -0.17701954 + ], + [ + 1260, + "\n ", + -0.07027287 + ], + [ + 10519, + " throw", + -1.3915133 + ], + [ + 2084, + " new", + -0.042013377 + ], + [ + 150858, + " RuntimeException", + -1.7330077 + ] + ] + }, + "generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException" + } +] \ No newline at end of file diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs new file mode 100644 index 00000000..3e68f6be --- /dev/null +++ b/launcher/tests/integration_tests.rs @@ -0,0 +1,156 @@ +use std::fs::File; +use serde_json::Value; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; +use std::thread; +use std::thread::sleep; +use std::time::Duration; +use float_eq::assert_float_eq; +use subprocess::{Popen, PopenConfig, Redirection}; +use serde::Deserialize; + +#[derive(Deserialize)] +struct Details { + finish_reason: String, + generated_tokens: u32, + tokens: Vec<(u32, String, Option)>, +} + +#[derive(Deserialize)] +struct GeneratedText { + generated_text: String, + details: Details, +} + + +fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen { + let argv = vec![ + "text-generation-launcher".to_string(), + "--model-name".to_string(), + model_name.clone(), + "--num-shard".to_string(), + num_shard.to_string(), + "--port".to_string(), + port.to_string(), + "--master-port".to_string(), + master_port.to_string(), + "--shard-uds-path".to_string(), + format!("/tmp/test-{}-{}-{}", num_shard, port, master_port), + ]; + + let mut launcher = Popen::create( + &argv, + PopenConfig { + stdout: Redirection::Pipe, + stderr: Redirection::Pipe, + ..Default::default() + }, + ) + .expect("Could not start launcher"); + + // Redirect STDOUT and STDERR to the console + let launcher_stdout = launcher.stdout.take().unwrap(); + let launcher_stderr = launcher.stderr.take().unwrap(); + + thread::spawn(move || { + let stdout = BufReader::new(launcher_stdout); + let stderr = BufReader::new(launcher_stderr); + for line in stdout.lines() { + println!("{}", line.unwrap()); + } + for line in stderr.lines() { + println!("{}", line.unwrap()); + } + }); + + for _ in 0..30 { + let health = reqwest::blocking::get(format!("http://localhost:{}/health", port)); + if health.is_ok() { + return launcher; + } + sleep(Duration::from_secs(2)); + } + + launcher.terminate().unwrap(); + launcher.wait().unwrap(); + panic!("failed to launch {}", model_name) +} + +fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText { + let mut launcher = start_launcher(model_name, num_shard, port, master_port); + + let data = r#" + { + "inputs": "Test request", + "parameters": { + "details": true + } + }"#; + let req: Value = serde_json::from_str(data).unwrap(); + + let client = reqwest::blocking::Client::new(); + let res = client + .post(format!("http://localhost:{}/generate", port)) + .json(&req) + .send(); + + launcher.terminate().unwrap(); + launcher.wait().unwrap(); + + let mut results: Vec = res.unwrap().json().unwrap(); + results.pop().unwrap() +} + + +fn read_json(name: &str) -> GeneratedText { + let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + d.push("tests/"); + d.push(name); + + let file = File::open(d).unwrap(); + let reader = BufReader::new(file); + + let mut results: Vec = serde_json::from_reader(reader).unwrap(); + results.pop().unwrap() +} + +fn compare_results(result: GeneratedText, expected: GeneratedText) { + assert_eq!(result.generated_text, expected.generated_text); + assert_eq!(result.details.finish_reason, expected.details.finish_reason); + assert_eq!(result.details.generated_tokens, expected.details.generated_tokens); + + for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) { + assert_eq!(token.0, expected_token.0); + assert_eq!(token.1, expected_token.1); + if let Some(logprob) = token.2 { + let expected_logprob = expected_token.2.unwrap(); + assert_float_eq!(logprob, expected_logprob, abs <= 0.001); + } else { + assert_eq!(token.2, expected_token.2); + } + } +} + +#[test] +fn test_bloom_560m() { + let expected = read_json("bloom_560m.json"); + + let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500); + compare_results(result, expected); +} + +#[test] +fn test_bloom_560m_distributed() { + let expected = read_json("bloom_560m.json"); + + let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501); + compare_results(result, expected); +} + +#[test] +fn test_mt0_base() { + let expected = read_json("mt0_base.json"); + + let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502); + compare_results(result, expected); +} diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json new file mode 100644 index 00000000..1b772282 --- /dev/null +++ b/launcher/tests/mt0_base.json @@ -0,0 +1,116 @@ +[ + { + "details": { + "finish_reason": "length", + "generated_tokens": 20, + "tokens": [ + [ + 0, + "", + null + ], + [ + 259, + "", + -1.3656927 + ], + [ + 215100, + "\"\"\"", + -2.6551573 + ], + [ + 46138, + "Test", + -1.8059857 + ], + [ + 287, + "the", + -1.2102449 + ], + [ + 259, + "", + -1.6057279 + ], + [ + 49076, + "contents", + -3.6060903 + ], + [ + 304, + "of", + -0.5270343 + ], + [ + 287, + "the", + -0.62522805 + ], + [ + 259, + "", + -1.4069618 + ], + [ + 49076, + "contents", + -2.621994 + ], + [ + 304, + "of", + -1.3172221 + ], + [ + 287, + "the", + -0.3501925 + ], + [ + 259, + "", + -0.7219573 + ], + [ + 49076, + "contents", + -1.0494149 + ], + [ + 260, + ".", + -1.0803378 + ], + [ + 259, + "", + -0.32933083 + ], + [ + 215100, + "\"\"\"", + -0.11268901 + ], + [ + 2978, + "test", + -1.5846587 + ], + [ + 290, + "_", + -0.49796978 + ], + [ + 4125, + "test", + -2.0026445 + ] + ] + }, + "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test" + } +] \ No newline at end of file