breaking(router): modify /generate API to only return generated text (#50)
@njhill, @yk FYI generated_text was concatenated to the user prompt for legacy reason. We want to remove this behaviour as we don't think it is useful and even detrimonial to usability. We also remove the unused Vec.
This commit is contained in:
parent
7b870e1e18
commit
b1482d9048
|
@ -118,6 +118,6 @@
|
|||
]
|
||||
]
|
||||
},
|
||||
"generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
|
||||
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
|
||||
}
|
||||
]
|
|
@ -97,8 +97,8 @@ fn test_model(
|
|||
launcher.terminate().unwrap();
|
||||
launcher.wait().unwrap();
|
||||
|
||||
let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
|
||||
results.pop().unwrap()
|
||||
let result: GeneratedText = res.unwrap().json().unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
fn read_json(name: &str) -> GeneratedText {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
mod infer;
|
||||
/// Text Generation Inference Webserver
|
||||
|
||||
mod infer;
|
||||
mod queue;
|
||||
pub mod server;
|
||||
mod validation;
|
||||
|
|
|
@ -125,10 +125,10 @@ async fn generate(
|
|||
tracing::info!("Output: {}", response.generated_text.text);
|
||||
|
||||
// Send response
|
||||
let response = vec![GenerateResponse {
|
||||
let response = GenerateResponse {
|
||||
generated_text: response.generated_text.text,
|
||||
details,
|
||||
}];
|
||||
};
|
||||
Ok((headers, Json(response)))
|
||||
}
|
||||
|
||||
|
|
|
@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
|
|||
assert len(generations) == 1
|
||||
assert (
|
||||
generations[0].generated_text.text
|
||||
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||
== "TestTestTestTestTestTestTestTestTestTest"
|
||||
)
|
||||
assert generations[0].request_id == default_bloom_batch.requests[0].id
|
||||
assert (
|
||||
|
@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
assert next_batch is not None
|
||||
|
||||
assert len(generations) == 2
|
||||
assert generations[1].generated_text.text == "TestTestTestTestTestTest"
|
||||
assert generations[1].generated_text.text == "TestTestTestTestTest"
|
||||
assert (
|
||||
generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
|
||||
)
|
||||
|
@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
assert len(generations) == 1
|
||||
assert (
|
||||
generations[0].generated_text.text
|
||||
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||
== "TestTestTestTestTestTestTestTestTestTest"
|
||||
)
|
||||
assert (
|
||||
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
|
||||
|
@ -261,7 +261,7 @@ def test_batch_concatenate(
|
|||
assert next_batch is not None
|
||||
|
||||
assert len(generations) == 3
|
||||
assert generations[2].generated_text.text == "TestTestTestTestTestTest"
|
||||
assert generations[2].generated_text.text == "TestTestTestTestTest"
|
||||
assert (
|
||||
generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
|
||||
)
|
||||
|
@ -284,7 +284,7 @@ def test_batch_concatenate(
|
|||
assert len(generations) == 2
|
||||
assert (
|
||||
generations[0].generated_text.text
|
||||
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||
== "TestTestTestTestTestTestTestTestTestTest"
|
||||
)
|
||||
assert generations[0].request_id == default_bloom_batch.requests[0].id
|
||||
assert (
|
||||
|
@ -307,7 +307,7 @@ def test_batch_concatenate(
|
|||
assert len(generations) == 1
|
||||
assert (
|
||||
generations[0].generated_text.text
|
||||
== "TestTestTestTestTestTestTestTestTestTestTest"
|
||||
== "TestTestTestTestTestTestTestTestTestTest"
|
||||
)
|
||||
assert (
|
||||
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
|
||||
|
|
|
@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
|
|||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
|
||||
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
|
||||
assert (
|
||||
generations[0].generated_text.generated_tokens
|
||||
|
@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
assert next_batch is not None
|
||||
|
||||
assert len(generations) == 2
|
||||
assert generations[1].generated_text.text == "Test.java:784)"
|
||||
assert generations[1].generated_text.text == ".java:784)"
|
||||
assert (
|
||||
generations[1].request_id
|
||||
== default_multi_requests_causal_lm_batch.requests[1].id
|
||||
|
@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
|
|||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
|
||||
assert (
|
||||
generations[0].request_id
|
||||
== default_multi_requests_causal_lm_batch.requests[0].id
|
||||
|
@ -255,7 +255,7 @@ def test_batch_concatenate(
|
|||
assert next_batch is not None
|
||||
|
||||
assert len(generations) == 3
|
||||
assert generations[2].generated_text.text == "Test.java:784)"
|
||||
assert generations[2].generated_text.text == ".java:784)"
|
||||
assert (
|
||||
generations[2].request_id
|
||||
== default_multi_requests_causal_lm_batch.requests[1].id
|
||||
|
@ -277,7 +277,7 @@ def test_batch_concatenate(
|
|||
assert next_batch is not None
|
||||
|
||||
assert len(generations) == 2
|
||||
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
|
||||
assert generations[0].request_id == default_causal_lm_batch.requests[0].id
|
||||
assert (
|
||||
generations[0].generated_text.generated_tokens
|
||||
|
@ -297,7 +297,7 @@ def test_batch_concatenate(
|
|||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
|
||||
assert generations[0].generated_text.text == ".java:784) at net.minecraft."
|
||||
assert (
|
||||
generations[0].request_id
|
||||
== default_multi_requests_causal_lm_batch.requests[0].id
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
|
|||
assert next_batch is None
|
||||
|
||||
assert len(generations) == 1
|
||||
assert generations[0].generated_text.text == "def test_get_all_users_with_"
|
||||
assert generations[0].generated_text.text == " test_get_all_users_with_"
|
||||
assert generations[0].request_id == batch.requests[0].id
|
||||
assert (
|
||||
generations[0].generated_text.generated_tokens
|
||||
|
@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
|
|||
assert len(generations) == 1
|
||||
assert (
|
||||
generations[0].generated_text.text
|
||||
== """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
|
||||
== """ineProperty(exports, "__esModule", { value"""
|
||||
)
|
||||
assert generations[0].request_id == batch.requests[0].id
|
||||
assert (
|
||||
|
|
|
@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
|
|||
def get_model(
|
||||
model_name: str, revision: Optional[str], sharded: bool, quantize: bool
|
||||
) -> Model:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
config = AutoConfig.from_pretrained(model_name, revision=revision)
|
||||
|
||||
if config.model_type == "bloom":
|
||||
if sharded:
|
||||
|
|
|
@ -360,11 +360,9 @@ class CausalLM(Model):
|
|||
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
generated_text = self.decode(
|
||||
output_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :, 0]
|
||||
)
|
||||
output_text = request.inputs + generated_text
|
||||
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
seed = next_token_chooser.choice.seed
|
||||
|
|
Loading…
Reference in New Issue