This commit is contained in:
Nicolas Patry 2023-09-11 18:25:49 +00:00
parent 4cce84301b
commit 33958e0989
7 changed files with 48 additions and 4 deletions

8
Cargo.lock generated
View File

@ -2866,7 +2866,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "1.0.1" version = "1.0.3"
dependencies = [ dependencies = [
"average", "average",
"clap", "clap",
@ -2886,7 +2886,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "1.0.1" version = "1.0.3"
dependencies = [ dependencies = [
"futures", "futures",
"grpc-metadata", "grpc-metadata",
@ -2902,7 +2902,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "1.0.1" version = "1.0.3"
dependencies = [ dependencies = [
"clap", "clap",
"ctrlc", "ctrlc",
@ -2918,7 +2918,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "1.0.1" version = "1.0.3"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"axum", "axum",

View File

@ -336,6 +336,10 @@ struct Args {
/// Display a lot of information about your runtime environment /// Display a lot of information about your runtime environment
#[clap(long, short, action)] #[clap(long, short, action)]
env: bool, env: bool,
/// Use speculation on a given model_id
#[clap(long, short)]
speculate_model_id: Option<String>,
} }
#[derive(Debug)] #[derive(Debug)]

View File

@ -17,6 +17,8 @@ service TextGenerationService {
rpc Prefill (PrefillRequest) returns (PrefillResponse); rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Decode token for a list of prefilled batches /// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse); rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Add to speculative ids to the given requests
rpc Speculate (SpeculateRequest) returns (SpeculateResponse);
/// Health check /// Health check
rpc Health (HealthRequest) returns (HealthResponse); rpc Health (HealthRequest) returns (HealthResponse);
} }
@ -93,6 +95,17 @@ message Request {
bool prefill_logprobs = 6; bool prefill_logprobs = 6;
/// Return most likely n tokens /// Return most likely n tokens
uint32 top_n_tokens = 7; uint32 top_n_tokens = 7;
/// The speculative generation
optional string speculate = 8;
}
message Speculate {
/// Request ID to speculate on
uint64 id = 1;
/// The generation context
string inputs = 2;
/// Context truncation
string speculation = 3;
} }
message Batch { message Batch {
@ -210,6 +223,17 @@ message DecodeResponse {
optional CachedBatch batch = 2; optional CachedBatch batch = 2;
} }
message SpeculateRequest {
/// Cached batches
repeated Speculate speculations = 1;
}
message SpeculateResponse {
// Next batch (cached)
// optional CachedBatch batch = 2;
}
message WarmupRequest { message WarmupRequest {
/// Batch to warmup on /// Batch to warmup on
Batch batch = 1; Batch batch = 1;

View File

@ -132,6 +132,7 @@ impl Client {
}), }),
prefill_logprobs: true, prefill_logprobs: true,
top_n_tokens: 20, top_n_tokens: 20,
speculate: None
}); });
n_tokens += max_input_length; n_tokens += max_input_length;
} }
@ -162,6 +163,18 @@ impl Client {
Ok((response.generations, response.batch)) Ok((response.generations, response.batch))
} }
/// Add speculation proposal to existing requests
///
#[instrument(skip_all)]
pub async fn speculate(
&mut self,
speculations: Vec<Speculate>,
) -> Result<()> {
let request = tonic::Request::new(SpeculateRequest { speculations }).inject_context();
let _response = self.stub.speculate(request).await?.into_inner();
Ok(())
}
/// Generate one token for each request in the given cached batches /// Generate one token for each request in the given cached batches
/// ///
/// Returns Generation for each request in batches /// Returns Generation for each request in batches

1
router/req.json Normal file

File diff suppressed because one or more lines are too long

View File

@ -51,6 +51,7 @@ impl Health {
ignore_eos_token: false, ignore_eos_token: false,
}), }),
top_n_tokens: 0, top_n_tokens: 0,
speculate: None
}; };
let batch = Batch { let batch = Batch {
id: BATCH_ID, id: BATCH_ID,

View File

@ -236,6 +236,7 @@ impl State {
parameters: Some(entry.request.parameters.clone()), parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()),
top_n_tokens: entry.request.top_n_tokens, top_n_tokens: entry.request.top_n_tokens,
speculate: None
}); });
// Set batch_time // Set batch_time
entry.batch_time = Some(Instant::now()); entry.batch_time = Some(Instant::now());