fix: include rust code for adapter id

This commit is contained in:
drbh 2024-06-06 23:23:17 +00:00
parent 68399c1ae3
commit 81707bfbfa
8 changed files with 12 additions and 12 deletions

View File

@ -157,7 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0), top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
adapter_index: None, adapter_id: None,
}) })
.collect(); .collect();

View File

@ -135,7 +135,7 @@ message Request {
/// Paged attention slots /// Paged attention slots
repeated uint32 slots = 10; repeated uint32 slots = 10;
/// LORA adapter index /// LORA adapter index
optional uint32 adapter_index = 11; optional string adapter_id = 11;
} }
message Batch { message Batch {

View File

@ -177,7 +177,7 @@ impl Client {
}), }),
prefill_logprobs: true, prefill_logprobs: true,
top_n_tokens: 20, top_n_tokens: 20,
adapter_index: None, adapter_id: None,
}); });
n_tokens += max_input_length; n_tokens += max_input_length;

View File

@ -244,7 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks // Block 0 is reserved for health checks
blocks: vec![0], blocks: vec![0],
slots: (0..16).collect(), slots: (0..16).collect(),
adapter_index: None, adapter_id: None,
}; };
let batch = Batch { let batch = Batch {
id: u64::MAX, id: u64::MAX,

View File

@ -429,7 +429,7 @@ mod tests {
stop_sequences: vec![], stop_sequences: vec![],
}, },
top_n_tokens: 0, top_n_tokens: 0,
adapter_index: None, adapter_id: None,
}, },
response_tx, response_tx,
span: info_span!("entry"), span: info_span!("entry"),

View File

@ -351,7 +351,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens, top_n_tokens: entry.request.top_n_tokens,
blocks, blocks,
slots, slots,
adapter_index: entry.request.adapter_index, adapter_id: entry.request.adapter_id.clone(),
}); });
// Set batch_time // Set batch_time
entry.batch_time = Some(Instant::now()); entry.batch_time = Some(Instant::now());
@ -492,7 +492,7 @@ mod tests {
stop_sequences: vec![], stop_sequences: vec![],
}, },
top_n_tokens: 0, top_n_tokens: 0,
adapter_index: None, adapter_id: None,
}, },
response_tx, response_tx,
span: info_span!("entry"), span: info_span!("entry"),

View File

@ -302,7 +302,7 @@ pub(crate) struct GenerateParameters {
/// Lora adapter id /// Lora adapter id
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = "null")] #[schema(nullable = true, default = "null", example = "null")]
pub adapter_index: Option<u32>, pub adapter_id: Option<String>,
} }
fn default_max_new_tokens() -> Option<u32> { fn default_max_new_tokens() -> Option<u32> {
@ -329,7 +329,7 @@ fn default_parameters() -> GenerateParameters {
seed: None, seed: None,
top_n_tokens: None, top_n_tokens: None,
grammar: None, grammar: None,
adapter_index: None, adapter_id: None,
} }
} }

View File

@ -202,7 +202,7 @@ impl Validation {
decoder_input_details, decoder_input_details,
top_n_tokens, top_n_tokens,
grammar, grammar,
adapter_index, adapter_id,
.. ..
} = request.parameters; } = request.parameters;
@ -384,7 +384,7 @@ impl Validation {
parameters, parameters,
stopping_parameters, stopping_parameters,
top_n_tokens, top_n_tokens,
adapter_index, adapter_id,
}) })
} }
@ -680,7 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters, pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters, pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32, pub top_n_tokens: u32,
pub adapter_index: Option<u32>, pub adapter_id: Option<String>,
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]