feat(backend): fix memory leaking on llama_sampler when the decode ends

This commit is contained in:
Morgan Funtowicz 2024-11-03 11:17:02 +01:00
parent 86a2ae6ba2
commit 7b0a56f40f
2 changed files with 10 additions and 7 deletions

View File

@ -29,7 +29,7 @@ namespace huggingface::tgi::backends::llamacpp {
batch.logits[batch.n_tokens] = true;
}
std::unique_ptr<llama_sampler> sampling_params_t::into_llama_sampler(const llama_model *model) const {
llama_sampler_ptr sampling_params_t::into_llama_sampler(const llama_model *model) const {
auto *pSampler = llama_sampler_chain_init({.no_perf = false});
// Penalties
@ -51,7 +51,7 @@ namespace huggingface::tgi::backends::llamacpp {
}
llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
return std::unique_ptr<llama_sampler>(pSampler);
return llama_sampler_ptr(pSampler, llama_sampler_deleter);
}
worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params)

View File

@ -24,7 +24,10 @@
namespace huggingface::tgi::backends::llamacpp {
static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_ptr;
static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback;
static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {};
@ -51,7 +54,7 @@ namespace huggingface::tgi::backends::llamacpp {
* @param Pointer to the model data
* @return
*/
std::unique_ptr<llama_sampler> into_llama_sampler(const llama_model *pModel) const;
llama_sampler_ptr into_llama_sampler(const llama_model *pModel) const;
};
/**
@ -155,7 +158,7 @@ namespace huggingface::tgi::backends::llamacpp {
class single_worker_backend_t : backend_base_t {
private:
constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_smart_ptr {
constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
auto llParams = llama_context_default_params();
llParams.flash_attn = true;
llParams.n_batch = 1;
@ -165,7 +168,7 @@ namespace huggingface::tgi::backends::llamacpp {
return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
};
llama_context_smart_ptr mContext_;
llama_context_ptr mContext_;
worker_t mWorker_;
public:
@ -185,7 +188,7 @@ namespace huggingface::tgi::backends::llamacpp {
class multi_worker_backend_t : backend_base_t {
private:
llama_context_smart_ptr mContext_;
llama_context_ptr mContext_;
public:
std::expected<size_t, backend_error_t> generate(