forward tgi parameters rep/freq penalty
This commit is contained in:
parent
95847c6587
commit
b643a436f3
|
@ -42,6 +42,8 @@ namespace huggingface::tgi::backends {
|
||||||
* @param topK
|
* @param topK
|
||||||
* @param topP
|
* @param topP
|
||||||
* @param temperature
|
* @param temperature
|
||||||
|
* @param repetition_penalty
|
||||||
|
* @param frequency_penalty
|
||||||
* @param seed
|
* @param seed
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -49,6 +51,8 @@ namespace huggingface::tgi::backends {
|
||||||
uint32_t topK,
|
uint32_t topK,
|
||||||
float_t topP,
|
float_t topP,
|
||||||
float_t temperature,
|
float_t temperature,
|
||||||
|
float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty,
|
||||||
uint64_t seed
|
uint64_t seed
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -84,6 +88,8 @@ namespace huggingface::tgi::backends {
|
||||||
* @param topK
|
* @param topK
|
||||||
* @param topP
|
* @param topP
|
||||||
* @param temperature
|
* @param temperature
|
||||||
|
* @param repetition_penalty
|
||||||
|
* @param frequency_penalty
|
||||||
* @param seed
|
* @param seed
|
||||||
* @return Request id related to this generation for reference
|
* @return Request id related to this generation for reference
|
||||||
*/
|
*/
|
||||||
|
@ -92,6 +98,8 @@ namespace huggingface::tgi::backends {
|
||||||
int32_t topK,
|
int32_t topK,
|
||||||
float_t topP,
|
float_t topP,
|
||||||
float_t temperature,
|
float_t temperature,
|
||||||
|
float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty,
|
||||||
uint64_t seed
|
uint64_t seed
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -40,12 +40,15 @@ namespace huggingface::tgi::backends {
|
||||||
* @param topK
|
* @param topK
|
||||||
* @param topP
|
* @param topP
|
||||||
* @param temperature
|
* @param temperature
|
||||||
|
* @param repetition_penalty
|
||||||
|
* @param frequency_penalty
|
||||||
* @param seed
|
* @param seed
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
||||||
uint64_t
|
uint64_t
|
||||||
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
|
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
|
||||||
|
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
||||||
|
|
||||||
/***
|
/***
|
||||||
*
|
*
|
||||||
|
|
|
@ -57,6 +57,8 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||||
uint32_t topK,
|
uint32_t topK,
|
||||||
float_t topP,
|
float_t topP,
|
||||||
float_t temperature,
|
float_t temperature,
|
||||||
|
float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty,
|
||||||
uint64_t seed) {
|
uint64_t seed) {
|
||||||
return tle::SamplingConfig(
|
return tle::SamplingConfig(
|
||||||
1, // TGI only use a single beam
|
1, // TGI only use a single beam
|
||||||
|
@ -66,9 +68,12 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||||
std::nullopt,
|
std::nullopt,
|
||||||
std::nullopt,
|
std::nullopt,
|
||||||
seed,
|
seed,
|
||||||
std::nullopt,
|
|
||||||
temperature,
|
temperature,
|
||||||
std::nullopt
|
temperature,
|
||||||
|
std::nullopt,
|
||||||
|
repetition_penalty,
|
||||||
|
std::nullopt,
|
||||||
|
frequency_penalty
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,6 +104,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||||
const int32_t topK,
|
const int32_t topK,
|
||||||
const float_t topP,
|
const float_t topP,
|
||||||
const float_t temperature,
|
const float_t temperature,
|
||||||
|
const float_t repetition_penalty,
|
||||||
|
const float_t frequency_penalty,
|
||||||
const uint64_t seed
|
const uint64_t seed
|
||||||
) {
|
) {
|
||||||
#ifdef NDEBUG
|
#ifdef NDEBUG
|
||||||
|
@ -118,7 +125,7 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||||
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
|
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
|
||||||
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
|
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
|
||||||
|
|
||||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, seed);
|
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
const auto output = tle::OutputConfig(true, false, false, true, false);
|
const auto output = tle::OutputConfig(true, false, false, true, false);
|
||||||
return executor.enqueueRequest(
|
return executor.enqueueRequest(
|
||||||
tle::Request{tokens, maxNewTokens, true, sampling, output});
|
tle::Request{tokens, maxNewTokens, true, sampling, output});
|
||||||
|
|
|
@ -140,6 +140,8 @@ impl TensorRtLlmBackend {
|
||||||
top_k: u32,
|
top_k: u32,
|
||||||
top_p: f32,
|
top_p: f32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
|
repetition_penalty: f32,
|
||||||
|
frequency_penalty: f32,
|
||||||
seed: u64,
|
seed: u64,
|
||||||
) {
|
) {
|
||||||
let tokenizer = Arc::clone(&self.tokenizer);
|
let tokenizer = Arc::clone(&self.tokenizer);
|
||||||
|
@ -174,10 +176,15 @@ impl TensorRtLlmBackend {
|
||||||
.in_scope(|| async {
|
.in_scope(|| async {
|
||||||
debug!("Acquiring lock for submit");
|
debug!("Acquiring lock for submit");
|
||||||
let mut handle = executor.write().await;
|
let mut handle = executor.write().await;
|
||||||
let request_id =
|
let request_id = handle.pin_mut().submit(
|
||||||
handle
|
&tokens,
|
||||||
.pin_mut()
|
top_k as i32,
|
||||||
.submit(&tokens, top_k as i32, top_p, temperature, seed);
|
top_p,
|
||||||
|
temperature,
|
||||||
|
repetition_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
seed,
|
||||||
|
);
|
||||||
|
|
||||||
debug!("Releasing lock for submit");
|
debug!("Releasing lock for submit");
|
||||||
request_id
|
request_id
|
||||||
|
|
|
@ -24,11 +24,13 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
||||||
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed) {
|
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty, uint64_t seed) {
|
||||||
|
|
||||||
// This will copy all the items from the initial slice
|
// This will copy all the items from the initial slice
|
||||||
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
|
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
|
||||||
return TensorRtLlmBackend::Submit(std::move(tokens_), topK, topP, temperature, seed);
|
return TensorRtLlmBackend::Submit(
|
||||||
|
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
||||||
|
|
|
@ -50,6 +50,8 @@ mod ffi {
|
||||||
top_k: i32,
|
top_k: i32,
|
||||||
top_p: f32,
|
top_p: f32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
|
repetition_penalty: f32,
|
||||||
|
frequency_penalty: f32,
|
||||||
seed: u64,
|
seed: u64,
|
||||||
) -> u64;
|
) -> u64;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue