diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp index 46052435..00692ea8 100644 --- a/backends/llamacpp/csrc/backend.cpp +++ b/backends/llamacpp/csrc/backend.cpp @@ -37,6 +37,7 @@ namespace huggingface::tgi::backends::llamacpp { llama_sampler_chain_add(pSampler, llama_sampler_init_top_p(top_p, 1)); } + llama_sampler_chain_add(pSampler, llama_sampler_init_temp(temperature)); llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed)); return {pSampler, llama_sampler_deleter}; } diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp index 321b667a..38fd3aad 100644 --- a/backends/llamacpp/csrc/backend.hpp +++ b/backends/llamacpp/csrc/backend.hpp @@ -48,6 +48,7 @@ namespace huggingface::tgi::backends::llamacpp { float_t top_p = 1.0f; float_t frequency_penalty = 0.0f; float_t repetition_penalty = 0.0f; + float_t temperature = 0.0f; uint64_t seed = 2014; /** diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs index d8f28ab9..e1575b1d 100644 --- a/backends/llamacpp/src/backend.rs +++ b/backends/llamacpp/src/backend.rs @@ -104,6 +104,7 @@ impl From<&ValidParameters> for SamplingParams { top_p: v.top_p, frequency_penalty: v.frequency_penalty, repetition_penalty: v.repetition_penalty, + temperature: v.temperature, seed: v.seed, } } diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs index d844bb9f..3507217f 100644 --- a/backends/llamacpp/src/lib.rs +++ b/backends/llamacpp/src/lib.rs @@ -10,6 +10,7 @@ impl Default for SamplingParams { top_p: 1.0f32, frequency_penalty: 0.0f32, repetition_penalty: 0.0f32, + temperature: 1.0f32, seed: 2014u64, } } @@ -29,6 +30,7 @@ mod ffi { top_p: f32, frequency_penalty: f32, repetition_penalty: f32, + temperature: f32, seed: u64, }