breaking(router): modify /generate API to only return generated text (#50)

@njhill, @yk FYI

generated_text was concatenated to the user prompt for legacy reasons. We
want to remove this behaviour, as we don't think it is useful and it can even
be detrimental to usability.

We also remove the now-unused Vec wrapper around the response.
OlivierDehaene 2023-02-02 15:02:04 +01:00 committed by GitHub
parent 7b870e1e18
commit b1482d9048
9 changed files with 23 additions and 24 deletions
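
For API consumers, the breaking change looks roughly like the sketch below. This is a minimal, hypothetical example: the URL, port, and payload fields are assumptions based on the usual /generate request shape, and the exact output depends on the model.

import requests

# Hypothetical local deployment; adjust the URL for your setup.
resp = requests.post(
    "http://localhost:3000/generate",
    json={"inputs": "Test", "parameters": {"max_new_tokens": 10}},
)

# Before this commit: a one-element JSON array whose "generated_text"
# also contained the prompt, e.g.
#   [{"generated_text": "Test request.get(\"action\");..."}]
# After this commit: a single JSON object holding only the new text:
#   {"generated_text": ".get(\"action\");..."}
print(resp.json()["generated_text"])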

View File

@@ -118,6 +118,6 @@
     ]
   ]
 },
-"generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
 }
]

View File

@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();
-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }
 fn read_json(name: &str) -> GeneratedText {

View File

@@ -1,5 +1,6 @@
-mod infer;
+/// Text Generation Inference Webserver
+mod infer;
 mod queue;
 pub mod server;
 mod validation;

View File

@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);
     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };
     Ok((headers, Json(response)))
 }

View File

@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
    )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id

View File

@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id

View File

@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (

View File

@@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)
     if config.model_type == "bloom":
         if sharded:
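
The fix above forwards the user-supplied revision to AutoConfig.from_pretrained, so the config is loaded from the same revision as the weights. A minimal sketch of why this matters (the model name and revision here are illustrative, not taken from this repo):

from transformers import AutoConfig

# Illustrative only: without revision=..., the config is always resolved
# from the repo's default branch, which can diverge from the revision of
# the weights actually being served.
config = AutoConfig.from_pretrained("bigscience/bloom-560m", revision="main")
print(config.model_type)  # "bloom"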

View File

@@ -360,11 +360,9 @@ class CausalLM(Model):
         if stop:
             # Decode generated tokens
-            generated_text = self.decode(
+            output_text = self.decode(
                 all_input_ids[-stopping_criteria.current_tokens :, 0]
             )
-            output_text = request.inputs + generated_text
             # Get seed
             if isinstance(next_token_chooser.choice, Sampling):
                 seed = next_token_chooser.choice.seed
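
The net effect of this last change, sketched with strings taken from the test expectations above (the prompt and decoded tokens are illustrative, not code from the repo):

# Sketch of the behaviour change.
prompt = "Test"
decoded_new_tokens = ".java:784) at net.minecraft."

# Before: the prompt was prepended to the decoded tokens.
assert prompt + decoded_new_tokens == "Test.java:784) at net.minecraft."

# After: only the newly generated text is returned.
assert decoded_new_tokens == ".java:784) at net.minecraft."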