breaking(router): modify /generate API to only return generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reason. We
want to remove this behaviour as we don't think it is useful and even
detrimonial to usability.

We also remove the unused Vec.
This commit is contained in:
OlivierDehaene 2023-02-02 15:02:04 +01:00 committed by GitHub
parent 7b870e1e18
commit b1482d9048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 23 additions and 24 deletions

View File

@ -118,6 +118,6 @@
] ]
] ]
}, },
"generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException" "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
} }
] ]

View File

@ -97,8 +97,8 @@ fn test_model(
launcher.terminate().unwrap(); launcher.terminate().unwrap();
launcher.wait().unwrap(); launcher.wait().unwrap();
let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap(); let result: GeneratedText = res.unwrap().json().unwrap();
results.pop().unwrap() result
} }
fn read_json(name: &str) -> GeneratedText { fn read_json(name: &str) -> GeneratedText {

View File

@ -1,5 +1,6 @@
mod infer;
/// Text Generation Inference Webserver /// Text Generation Inference Webserver
mod infer;
mod queue; mod queue;
pub mod server; pub mod server;
mod validation; mod validation;

View File

@ -125,10 +125,10 @@ async fn generate(
tracing::info!("Output: {}", response.generated_text.text); tracing::info!("Output: {}", response.generated_text.text);
// Send response // Send response
let response = vec![GenerateResponse { let response = GenerateResponse {
generated_text: response.generated_text.text, generated_text: response.generated_text.text,
details, details,
}]; };
Ok((headers, Json(response))) Ok((headers, Json(response)))
} }

View File

@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
assert len(generations) == 1 assert len(generations) == 1
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest" == "TestTestTestTestTestTestTestTestTestTest"
) )
assert generations[0].request_id == default_bloom_batch.requests[0].id assert generations[0].request_id == default_bloom_batch.requests[0].id
assert ( assert (
@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is not None assert next_batch is not None
assert len(generations) == 2 assert len(generations) == 2
assert generations[1].generated_text.text == "TestTestTestTestTestTest" assert generations[1].generated_text.text == "TestTestTestTestTest"
assert ( assert (
generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
) )
@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
assert len(generations) == 1 assert len(generations) == 1
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest" == "TestTestTestTestTestTestTestTestTestTest"
) )
assert ( assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@ -261,7 +261,7 @@ def test_batch_concatenate(
assert next_batch is not None assert next_batch is not None
assert len(generations) == 3 assert len(generations) == 3
assert generations[2].generated_text.text == "TestTestTestTestTestTest" assert generations[2].generated_text.text == "TestTestTestTestTest"
assert ( assert (
generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
) )
@ -284,7 +284,7 @@ def test_batch_concatenate(
assert len(generations) == 2 assert len(generations) == 2
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest" == "TestTestTestTestTestTestTestTestTestTest"
) )
assert generations[0].request_id == default_bloom_batch.requests[0].id assert generations[0].request_id == default_bloom_batch.requests[0].id
assert ( assert (
@ -307,7 +307,7 @@ def test_batch_concatenate(
assert len(generations) == 1 assert len(generations) == 1
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text
== "TestTestTestTestTestTestTestTestTestTestTest" == "TestTestTestTestTestTestTestTestTestTest"
) )
assert ( assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id

View File

@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
assert next_batch is None assert next_batch is None
assert len(generations) == 1 assert len(generations) == 1
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert ( assert (
generations[0].generated_text.generated_tokens generations[0].generated_text.generated_tokens
@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is not None assert next_batch is not None
assert len(generations) == 2 assert len(generations) == 2
assert generations[1].generated_text.text == "Test.java:784)" assert generations[1].generated_text.text == ".java:784)"
assert ( assert (
generations[1].request_id generations[1].request_id
== default_multi_requests_causal_lm_batch.requests[1].id == default_multi_requests_causal_lm_batch.requests[1].id
@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
assert next_batch is None assert next_batch is None
assert len(generations) == 1 assert len(generations) == 1
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert ( assert (
generations[0].request_id generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id == default_multi_requests_causal_lm_batch.requests[0].id
@ -255,7 +255,7 @@ def test_batch_concatenate(
assert next_batch is not None assert next_batch is not None
assert len(generations) == 3 assert len(generations) == 3
assert generations[2].generated_text.text == "Test.java:784)" assert generations[2].generated_text.text == ".java:784)"
assert ( assert (
generations[2].request_id generations[2].request_id
== default_multi_requests_causal_lm_batch.requests[1].id == default_multi_requests_causal_lm_batch.requests[1].id
@ -277,7 +277,7 @@ def test_batch_concatenate(
assert next_batch is not None assert next_batch is not None
assert len(generations) == 2 assert len(generations) == 2
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert generations[0].request_id == default_causal_lm_batch.requests[0].id assert generations[0].request_id == default_causal_lm_batch.requests[0].id
assert ( assert (
generations[0].generated_text.generated_tokens generations[0].generated_text.generated_tokens
@ -297,7 +297,7 @@ def test_batch_concatenate(
assert next_batch is None assert next_batch is None
assert len(generations) == 1 assert len(generations) == 1
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." assert generations[0].generated_text.text == ".java:784) at net.minecraft."
assert ( assert (
generations[0].request_id generations[0].request_id
== default_multi_requests_causal_lm_batch.requests[0].id == default_multi_requests_causal_lm_batch.requests[0].id

View File

@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
assert next_batch is None assert next_batch is None
assert len(generations) == 1 assert len(generations) == 1
assert generations[0].generated_text.text == "def test_get_all_users_with_" assert generations[0].generated_text.text == " test_get_all_users_with_"
assert generations[0].request_id == batch.requests[0].id assert generations[0].request_id == batch.requests[0].id
assert ( assert (
generations[0].generated_text.generated_tokens generations[0].generated_text.generated_tokens
@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
assert len(generations) == 1 assert len(generations) == 1
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value""" == """ineProperty(exports, "__esModule", { value"""
) )
assert generations[0].request_id == batch.requests[0].id assert generations[0].request_id == batch.requests[0].id
assert ( assert (

View File

@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
def get_model( def get_model(
model_name: str, revision: Optional[str], sharded: bool, quantize: bool model_name: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model: ) -> Model:
config = AutoConfig.from_pretrained(model_name) config = AutoConfig.from_pretrained(model_name, revision=revision)
if config.model_type == "bloom": if config.model_type == "bloom":
if sharded: if sharded:

View File

@ -360,11 +360,9 @@ class CausalLM(Model):
if stop: if stop:
# Decode generated tokens # Decode generated tokens
generated_text = self.decode( output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :, 0] all_input_ids[-stopping_criteria.current_tokens :, 0]
) )
output_text = request.inputs + generated_text
# Get seed # Get seed
if isinstance(next_token_chooser.choice, Sampling): if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed seed = next_token_chooser.choice.seed