Forward TGI repetition/frequency penalty parameters

This commit is contained in:
Morgan Funtowicz 2024-07-18 20:56:58 +00:00
parent 95847c6587
commit b643a436f3
6 changed files with 39 additions and 10 deletions

View File

@ -42,6 +42,8 @@ namespace huggingface::tgi::backends {
* @param topK * @param topK
* @param topP * @param topP
* @param temperature * @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed * @param seed
* @return * @return
*/ */
@ -49,6 +51,8 @@ namespace huggingface::tgi::backends {
uint32_t topK, uint32_t topK,
float_t topP, float_t topP,
float_t temperature, float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed uint64_t seed
); );
@ -84,6 +88,8 @@ namespace huggingface::tgi::backends {
* @param topK * @param topK
* @param topP * @param topP
* @param temperature * @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed * @param seed
* @return Request id related to this generation for reference * @return Request id related to this generation for reference
*/ */
@ -92,6 +98,8 @@ namespace huggingface::tgi::backends {
int32_t topK, int32_t topK,
float_t topP, float_t topP,
float_t temperature, float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed uint64_t seed
); );

View File

@ -40,12 +40,15 @@ namespace huggingface::tgi::backends {
* @param topK * @param topK
* @param topP * @param topP
* @param temperature * @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed * @param seed
* @return * @return
*/ */
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]] [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t uint64_t
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed); Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
/*** /***
* *

View File

@ -57,6 +57,8 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
uint32_t topK, uint32_t topK,
float_t topP, float_t topP,
float_t temperature, float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed) { uint64_t seed) {
return tle::SamplingConfig( return tle::SamplingConfig(
1, // TGI only use a single beam 1, // TGI only use a single beam
@ -66,9 +68,12 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
std::nullopt, std::nullopt,
std::nullopt, std::nullopt,
seed, seed,
std::nullopt,
temperature, temperature,
std::nullopt temperature,
std::nullopt,
repetition_penalty,
std::nullopt,
frequency_penalty
); );
} }
@ -99,6 +104,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const int32_t topK, const int32_t topK,
const float_t topP, const float_t topP,
const float_t temperature, const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed const uint64_t seed
) { ) {
#ifdef NDEBUG #ifdef NDEBUG
@ -118,7 +125,7 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>(); const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size())); const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
const auto sampling = GetSamplingConfig(topK, topP, temperature, seed); const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
const auto output = tle::OutputConfig(true, false, false, true, false); const auto output = tle::OutputConfig(true, false, false, true, false);
return executor.enqueueRequest( return executor.enqueueRequest(
tle::Request{tokens, maxNewTokens, true, sampling, output}); tle::Request{tokens, maxNewTokens, true, sampling, output});

View File

@ -140,6 +140,8 @@ impl TensorRtLlmBackend {
top_k: u32, top_k: u32,
top_p: f32, top_p: f32,
temperature: f32, temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64, seed: u64,
) { ) {
let tokenizer = Arc::clone(&self.tokenizer); let tokenizer = Arc::clone(&self.tokenizer);
@ -174,10 +176,15 @@ impl TensorRtLlmBackend {
.in_scope(|| async { .in_scope(|| async {
debug!("Acquiring lock for submit"); debug!("Acquiring lock for submit");
let mut handle = executor.write().await; let mut handle = executor.write().await;
let request_id = let request_id = handle.pin_mut().submit(
handle &tokens,
.pin_mut() top_k as i32,
.submit(&tokens, top_k as i32, top_p, temperature, seed); top_p,
temperature,
repetition_penalty,
frequency_penalty,
seed,
);
debug!("Releasing lock for submit"); debug!("Releasing lock for submit");
request_id request_id

View File

@ -24,11 +24,13 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
} }
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed) { rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
float_t frequency_penalty, uint64_t seed) {
// This will copy all the items from the initial slice // This will copy all the items from the initial slice
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end())); std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
return TensorRtLlmBackend::Submit(std::move(tokens_), topK, topP, temperature, seed); return TensorRtLlmBackend::Submit(
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
} }
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens( size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(

View File

@ -50,6 +50,8 @@ mod ffi {
top_k: i32, top_k: i32,
top_p: f32, top_p: f32,
temperature: f32, temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64, seed: u64,
) -> u64; ) -> u64;