fix: include rust code for adapter id

This commit is contained in:
drbh 2024-06-06 23:23:17 +00:00
parent 68399c1ae3
commit 81707bfbfa
8 changed files with 12 additions and 12 deletions

View File

@ -157,7 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
adapter_index: None,
adapter_id: None,
})
.collect();

View File

@ -135,7 +135,7 @@ message Request {
/// Paged attention slots
repeated uint32 slots = 10;
/// LORA adapter index
optional uint32 adapter_index = 11;
optional string adapter_id = 11;
}
message Batch {

View File

@ -177,7 +177,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_index: None,
adapter_id: None,
});
n_tokens += max_input_length;

View File

@ -244,7 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
adapter_index: None,
adapter_id: None,
};
let batch = Batch {
id: u64::MAX,

View File

@ -429,7 +429,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_index: None,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),

View File

@ -351,7 +351,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
adapter_index: entry.request.adapter_index,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -492,7 +492,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_index: None,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),

View File

@ -302,7 +302,7 @@ pub(crate) struct GenerateParameters {
/// Lora adapter id
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_index: Option<u32>,
pub adapter_id: Option<String>,
}
fn default_max_new_tokens() -> Option<u32> {
@ -329,7 +329,7 @@ fn default_parameters() -> GenerateParameters {
seed: None,
top_n_tokens: None,
grammar: None,
adapter_index: None,
adapter_id: None,
}
}

View File

@ -202,7 +202,7 @@ impl Validation {
decoder_input_details,
top_n_tokens,
grammar,
adapter_index,
adapter_id,
..
} = request.parameters;
@ -384,7 +384,7 @@ impl Validation {
parameters,
stopping_parameters,
top_n_tokens,
adapter_index,
adapter_id,
})
}
@ -680,7 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
pub adapter_index: Option<u32>,
pub adapter_id: Option<String>,
}
#[derive(Error, Debug)]