fix: include rust code for adapter id
This commit is contained in:
parent
68399c1ae3
commit
81707bfbfa
|
@ -157,7 +157,7 @@ async fn prefill(
|
||||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
slots: vec![],
|
||||||
adapter_index: None,
|
adapter_id: None,
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
|
@ -135,7 +135,7 @@ message Request {
|
||||||
/// Paged attention slots
|
/// Paged attention slots
|
||||||
repeated uint32 slots = 10;
|
repeated uint32 slots = 10;
|
||||||
/// LORA adapter index
|
/// LORA adapter index
|
||||||
optional uint32 adapter_index = 11;
|
optional string adapter_id = 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Batch {
|
message Batch {
|
||||||
|
|
|
@ -177,7 +177,7 @@ impl Client {
|
||||||
}),
|
}),
|
||||||
prefill_logprobs: true,
|
prefill_logprobs: true,
|
||||||
top_n_tokens: 20,
|
top_n_tokens: 20,
|
||||||
adapter_index: None,
|
adapter_id: None,
|
||||||
});
|
});
|
||||||
n_tokens += max_input_length;
|
n_tokens += max_input_length;
|
||||||
|
|
||||||
|
|
|
@ -244,7 +244,7 @@ impl Health for ShardedClient {
|
||||||
// Block 0 is reserved for health checks
|
// Block 0 is reserved for health checks
|
||||||
blocks: vec![0],
|
blocks: vec![0],
|
||||||
slots: (0..16).collect(),
|
slots: (0..16).collect(),
|
||||||
adapter_index: None,
|
adapter_id: None,
|
||||||
};
|
};
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
id: u64::MAX,
|
id: u64::MAX,
|
||||||
|
|
|
@ -429,7 +429,7 @@ mod tests {
|
||||||
stop_sequences: vec![],
|
stop_sequences: vec![],
|
||||||
},
|
},
|
||||||
top_n_tokens: 0,
|
top_n_tokens: 0,
|
||||||
adapter_index: None,
|
adapter_id: None,
|
||||||
},
|
},
|
||||||
response_tx,
|
response_tx,
|
||||||
span: info_span!("entry"),
|
span: info_span!("entry"),
|
||||||
|
|
|
@ -351,7 +351,7 @@ impl State {
|
||||||
top_n_tokens: entry.request.top_n_tokens,
|
top_n_tokens: entry.request.top_n_tokens,
|
||||||
blocks,
|
blocks,
|
||||||
slots,
|
slots,
|
||||||
adapter_index: entry.request.adapter_index,
|
adapter_id: entry.request.adapter_id.clone(),
|
||||||
});
|
});
|
||||||
// Set batch_time
|
// Set batch_time
|
||||||
entry.batch_time = Some(Instant::now());
|
entry.batch_time = Some(Instant::now());
|
||||||
|
@ -492,7 +492,7 @@ mod tests {
|
||||||
stop_sequences: vec![],
|
stop_sequences: vec![],
|
||||||
},
|
},
|
||||||
top_n_tokens: 0,
|
top_n_tokens: 0,
|
||||||
adapter_index: None,
|
adapter_id: None,
|
||||||
},
|
},
|
||||||
response_tx,
|
response_tx,
|
||||||
span: info_span!("entry"),
|
span: info_span!("entry"),
|
||||||
|
|
|
@ -302,7 +302,7 @@ pub(crate) struct GenerateParameters {
|
||||||
/// Lora adapter id
|
/// Lora adapter id
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
pub adapter_index: Option<u32>,
|
pub adapter_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_max_new_tokens() -> Option<u32> {
|
fn default_max_new_tokens() -> Option<u32> {
|
||||||
|
@ -329,7 +329,7 @@ fn default_parameters() -> GenerateParameters {
|
||||||
seed: None,
|
seed: None,
|
||||||
top_n_tokens: None,
|
top_n_tokens: None,
|
||||||
grammar: None,
|
grammar: None,
|
||||||
adapter_index: None,
|
adapter_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -202,7 +202,7 @@ impl Validation {
|
||||||
decoder_input_details,
|
decoder_input_details,
|
||||||
top_n_tokens,
|
top_n_tokens,
|
||||||
grammar,
|
grammar,
|
||||||
adapter_index,
|
adapter_id,
|
||||||
..
|
..
|
||||||
} = request.parameters;
|
} = request.parameters;
|
||||||
|
|
||||||
|
@ -384,7 +384,7 @@ impl Validation {
|
||||||
parameters,
|
parameters,
|
||||||
stopping_parameters,
|
stopping_parameters,
|
||||||
top_n_tokens,
|
top_n_tokens,
|
||||||
adapter_index,
|
adapter_id,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -680,7 +680,7 @@ pub(crate) struct ValidGenerateRequest {
|
||||||
pub parameters: ValidParameters,
|
pub parameters: ValidParameters,
|
||||||
pub stopping_parameters: ValidStoppingParameters,
|
pub stopping_parameters: ValidStoppingParameters,
|
||||||
pub top_n_tokens: u32,
|
pub top_n_tokens: u32,
|
||||||
pub adapter_index: Option<u32>,
|
pub adapter_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
|
|
Loading…
Reference in New Issue