diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json
index a81d1982..3a0a3d99 100644
--- a/launcher/tests/bloom_560m.json
+++ b/launcher/tests/bloom_560m.json
@@ -118,6 +118,6 @@
         ]
       ]
     },
-    "generated_text": "Test request.get(\"action\");\n        if (action == null) {\n            throw new RuntimeException"
+    "generated_text": ".get(\"action\");\n        if (action == null) {\n            throw new RuntimeException"
   }
 ]
\ No newline at end of file
diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs
index 6b270659..4f699f69 100644
--- a/launcher/tests/integration_tests.rs
+++ b/launcher/tests/integration_tests.rs
@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();

-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }

 fn read_json(name: &str) -> GeneratedText {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index c6ac2022..542a8f78 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1,5 +1,6 @@
-mod infer;
 /// Text Generation Inference Webserver
+
+mod infer;
 mod queue;
 pub mod server;
 mod validation;
diff --git a/router/src/server.rs b/router/src/server.rs
index c31ca6ce..f79644e3 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };
     Ok((headers, Json(response)))
 }

diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 9f96efc3..871c0da0 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index f9762b30..6a822815 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py
index 1b69477d..1596e413 100644
--- a/server/tests/models/test_santacoder.py
+++ b/server/tests/models/test_santacoder.py
@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """defworldineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index 9309c887..15d8e97e 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)

     if config.model_type == "bloom":
         if sharded:
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 994c57d5..1d1945cd 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -360,11 +360,9 @@ class CausalLM(Model):
             if stop:
                 # Decode generated tokens
-                generated_text = self.decode(
+                output_text = self.decode(
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
-                output_text = request.inputs + generated_text
-
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
                     seed = next_token_chooser.choice.seed