From 81707bfbfa5af7465f45c466e61d421de889e8ef Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 6 Jun 2024 23:23:17 +0000 Subject: [PATCH] fix: include rust code for adapter id --- benchmark/src/generation.rs | 2 +- proto/v3/generate.proto | 2 +- router/client/src/v3/client.rs | 2 +- router/client/src/v3/sharded_client.rs | 2 +- router/src/infer/v2/queue.rs | 2 +- router/src/infer/v3/queue.rs | 4 ++-- router/src/lib.rs | 4 ++-- router/src/validation.rs | 6 +++--- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index e00437c7..5e739703 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -157,7 +157,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], - adapter_index: None, + adapter_id: None, }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index c7a8013b..926c878e 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -135,7 +135,7 @@ message Request { /// Paged attention slots repeated uint32 slots = 10; /// LORA adapter index - optional uint32 adapter_index = 11; + optional string adapter_id = 11; } message Batch { diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 5ced4056..a996b14f 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -177,7 +177,7 @@ impl Client { }), prefill_logprobs: true, top_n_tokens: 20, - adapter_index: None, + adapter_id: None, }); n_tokens += max_input_length; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 300decca..ae8a899b 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -244,7 +244,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), - adapter_index: None, + adapter_id: None, }; let batch = Batch { id: u64::MAX, diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index f0205697..93cf9469 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -429,7 +429,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, - adapter_index: None, + adapter_id: None, }, response_tx, span: info_span!("entry"), diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index fbfdf715..ba65b9b6 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -351,7 +351,7 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, - adapter_index: entry.request.adapter_index, + adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -492,7 +492,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, - adapter_index: None, + adapter_id: None, }, response_tx, span: info_span!("entry"), diff --git a/router/src/lib.rs b/router/src/lib.rs index c99e8281..08d57873 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -302,7 +302,7 @@ pub(crate) struct GenerateParameters { /// Lora adapter id #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] - pub adapter_index: Option, + pub adapter_id: Option, } fn default_max_new_tokens() -> Option { @@ -329,7 +329,7 @@ fn default_parameters() -> GenerateParameters { seed: None, top_n_tokens: None, grammar: None, - adapter_index: None, + adapter_id: None, } } diff --git a/router/src/validation.rs b/router/src/validation.rs index 6f776870..e2bf5a5d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -202,7 +202,7 @@ impl Validation { decoder_input_details, top_n_tokens, grammar, - adapter_index, + adapter_id, .. } = request.parameters; @@ -384,7 +384,7 @@ impl Validation { parameters, stopping_parameters, top_n_tokens, - adapter_index, + adapter_id, }) } @@ -680,7 +680,7 @@ pub(crate) struct ValidGenerateRequest { pub parameters: ValidParameters, pub stopping_parameters: ValidStoppingParameters, pub top_n_tokens: u32, - pub adapter_index: Option, + pub adapter_id: Option, } #[derive(Error, Debug)]