Nicolas Patry 2023-09-11 18:25:49 +00:00
parent 4cce84301b
commit 33958e0989
7 changed files with 48 additions and 4 deletions

Cargo.lock (generated)

@@ -2866,7 +2866,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "1.0.1"
version = "1.0.3"
dependencies = [
"average",
"clap",
@@ -2886,7 +2886,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "1.0.1"
version = "1.0.3"
dependencies = [
"futures",
"grpc-metadata",
@@ -2902,7 +2902,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "1.0.1"
version = "1.0.3"
dependencies = [
"clap",
"ctrlc",
@@ -2918,7 +2918,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "1.0.1"
version = "1.0.3"
dependencies = [
"async-stream",
"axum",


@@ -336,6 +336,10 @@ struct Args {
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
/// Use speculation on a given model_id
#[clap(long, short)]
speculate_model_id: Option<String>,
}
#[derive(Debug)]
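
A minimal sketch of how the launcher might read the new flag once clap has parsed it; Args and speculate_model_id come from the struct above, while the function body and log line are purely illustrative:

use clap::Parser;

fn main() {
    let args = Args::parse();
    // If a draft model id was supplied, the launcher could forward it to the shards.
    if let Some(draft_model) = &args.speculate_model_id {
        println!("speculative decoding requested with draft model {draft_model}");
    }
}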


@@ -17,6 +17,8 @@ service TextGenerationService {
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Add speculative ids to the given requests
rpc Speculate (SpeculateRequest) returns (SpeculateResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
}
@@ -93,6 +95,17 @@ message Request {
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
/// The speculative generation
optional string speculate = 8;
}
message Speculate {
/// Request ID to speculate on
uint64 id = 1;
/// The generation context
string inputs = 2;
/// Context truncation
string speculation = 3;
}
message Batch {
@@ -210,6 +223,17 @@ message DecodeResponse {
optional CachedBatch batch = 2;
}
message SpeculateRequest {
/// Speculations to add to existing requests
repeated Speculate speculations = 1;
}
message SpeculateResponse {
// Next batch (cached)
// optional CachedBatch batch = 2;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
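
For orientation, a minimal sketch of how the new messages might be populated from Rust, assuming prost-generated types whose fields mirror the proto definitions above; the concrete values are illustrative only:

// Sketch only: Speculate and SpeculateRequest are the generated message types.
let speculation = Speculate {
    // id of a request that is already part of a prefilled batch
    id: 0,
    // the generation context accumulated so far
    inputs: "The quick brown fox".to_string(),
    // the proposed continuation for the target model to verify
    speculation: " jumps over the lazy dog".to_string(),
};
let request = SpeculateRequest {
    speculations: vec![speculation],
};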


@@ -132,6 +132,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
speculate: None
});
n_tokens += max_input_length;
}
@@ -162,6 +163,18 @@ impl Client {
Ok((response.generations, response.batch))
}
/// Add speculation proposals to existing requests
///
#[instrument(skip_all)]
pub async fn speculate(
&mut self,
speculations: Vec<Speculate>,
) -> Result<()> {
let request = tonic::Request::new(SpeculateRequest { speculations }).inject_context();
let _response = self.stub.speculate(request).await?.into_inner();
Ok(())
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
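
A hedged usage sketch for the new method, assuming an already-connected Client; request_id, current_context and draft_continuation are illustrative placeholders, not names from this diff:

// Sketch only: call site for the Client::speculate method added above.
let speculations = vec![Speculate {
    id: request_id,                  // request already known to the shard
    inputs: current_context.clone(), // generation context so far
    speculation: draft_continuation, // proposed continuation to verify
}];
// SpeculateResponse carries no fields yet, so only a potential error matters here.
client.speculate(speculations).await?;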

router/req.json (new file)

File diff suppressed because one or more lines are too long


@@ -51,6 +51,7 @@ impl Health {
ignore_eos_token: false,
}),
top_n_tokens: 0,
speculate: None
};
let batch = Batch {
id: BATCH_ID,


@@ -236,6 +236,7 @@ impl State {
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
top_n_tokens: entry.request.top_n_tokens,
speculate: None
});
// Set batch_time
entry.batch_time = Some(Instant::now());