From b1482d90488b0e92a0adc06d47a603b5d837442c Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Thu, 2 Feb 2023 15:02:04 +0100
Subject: [PATCH] breaking(router): modify /generate API to only return
 generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reasons. We
want to remove this behaviour, as we don't think it is useful and it is even
detrimental to usability. We also remove the unused Vec wrapper around the
response. A brief client-side sketch of the new response shape follows the
patch.
---
 launcher/tests/bloom_560m.json             |  2 +-
 launcher/tests/integration_tests.rs        |  4 ++--
 router/src/lib.rs                          |  3 ++-
 router/src/server.rs                       |  4 ++--
 server/tests/models/test_bloom.py          | 12 ++++++------
 server/tests/models/test_causal_lm.py      | 12 ++++++------
 server/tests/models/test_santacoder.py     |  4 ++--
 server/text_generation/models/__init__.py  |  2 +-
 server/text_generation/models/causal_lm.py |  4 +---
 9 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json
index a81d1982..3a0a3d99 100644
--- a/launcher/tests/bloom_560m.json
+++ b/launcher/tests/bloom_560m.json
@@ -118,6 +118,6 @@
       ]
     ]
   },
-  "generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
   }
 ]
\ No newline at end of file
diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs
index 6b270659..4f699f69 100644
--- a/launcher/tests/integration_tests.rs
+++ b/launcher/tests/integration_tests.rs
@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();
 
-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }
 
 fn read_json(name: &str) -> GeneratedText {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index c6ac2022..542a8f78 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1,5 +1,6 @@
-mod infer;
 /// Text Generation Inference Webserver
+
+mod infer;
 mod queue;
 pub mod server;
 mod validation;
diff --git a/router/src/server.rs b/router/src/server.rs
index c31ca6ce..f79644e3 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);
 
     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };
     Ok((headers, Json(response)))
 }
diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 9f96efc3..871c0da0 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
 
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index f9762b30..6a822815 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
 
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
    assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py
index 1b69477d..1596e413 100644
--- a/server/tests/models/test_santacoder.py
+++ b/server/tests/models/test_santacoder.py
@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """defworldineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index 9309c887..15d8e97e 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)
 
     if config.model_type == "bloom":
         if sharded:
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 994c57d5..1d1945cd 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -360,11 +360,9 @@ class CausalLM(Model):
 
             if stop:
                 # Decode generated tokens
-                generated_text = self.decode(
+                output_text = self.decode(
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
-                output_text = request.inputs + generated_text
-
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
                     seed = next_token_chooser.choice.seed
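
Note for API consumers (illustration, not part of the patch): a minimal sketch of how a client adapts to the new /generate response, assuming a router listening on http://localhost:3000 and the existing {"inputs": ...} request body; the prompt string and file name are made up.

# client_example.py -- hypothetical snippet, not included in this change
import requests

prompt = "Test"
resp = requests.post(
    "http://localhost:3000/generate",  # assumed local router address
    json={"inputs": prompt},  # the request body is not changed by this patch
)
resp.raise_for_status()

# Before this patch, the body was a one-element list whose "generated_text"
# started with the prompt, so clients typically did:
#     completion = resp.json()[0]["generated_text"][len(prompt):]
# After this patch, the body is a single object and "generated_text" holds
# only the newly generated text:
completion = resp.json()["generated_text"]
print(completion)

Returning a single object instead of a one-element list also simplifies typed clients, as in integration_tests.rs above, which now deserializes directly into GeneratedText rather than popping the last element of a Vec.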