breaking(router): modify /generate API to only return generated text (#50)
@njhill, @yk FYI: generated_text was concatenated to the user prompt for legacy reasons. We want to remove this behaviour as we don't think it is useful, and it is even detrimental to usability. We also remove the unused Vec wrapper around the response.
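For illustration, a minimal client-side sketch of the change (assuming the router listens on localhost:3000 and accepts the usual {"inputs": ..., "parameters": ...} payload; anything not visible in this diff is illustrative, not the exact API):

# Sketch only: shows how the /generate response shape changes with this commit.
import requests

prompt = "Test request"
resp = requests.post(
    "http://localhost:3000/generate",
    json={"inputs": prompt, "parameters": {"max_new_tokens": 20}},
)
resp.raise_for_status()

# Before this commit: a one-element array whose generated_text echoed the prompt,
# e.g. [{"generated_text": "Test request.get(\"action\");...", "details": {...}}]
# After this commit: a single object holding only the newly generated text,
# e.g. {"generated_text": ".get(\"action\");...", "details": {...}}
result = resp.json()
print(result["generated_text"])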
parent 7b870e1e18
commit b1482d9048
@@ -118,6 +118,6 @@
         ]
       ]
     },
-    "generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+    "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
   }
 ]
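The fixture above shows the prompt prefix ("Test request") disappearing from generated_text. Clients that still want the old concatenated output can rebuild it themselves; a minimal sketch under the same assumptions as the example above:

import requests

def generate_with_prompt_echo(prompt: str, max_new_tokens: int = 20) -> str:
    # Hypothetical helper, not part of this repo: recover the pre-#50 behaviour
    # by prepending the prompt to the text returned by /generate.
    resp = requests.post(
        "http://localhost:3000/generate",
        json={"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}},
    )
    resp.raise_for_status()
    return prompt + resp.json()["generated_text"]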
@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();

-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }

 fn read_json(name: &str) -> GeneratedText {
@@ -1,5 +1,6 @@
-mod infer;
 /// Text Generation Inference Webserver
+
+mod infer;
 mod queue;
 pub mod server;
 mod validation;
@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };
     Ok((headers, Json(response)))
 }

@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (
@@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)

     if config.model_type == "bloom":
         if sharded:
@@ -360,11 +360,9 @@ class CausalLM(Model):

             if stop:
                 # Decode generated tokens
-                generated_text = self.decode(
+                output_text = self.decode(
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
-                output_text = request.inputs + generated_text
-
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
                     seed = next_token_chooser.choice.seed