feat(backend): fix memory leaking on llama_sampler when the decode ends

This commit is contained in:
Morgan Funtowicz 2024-11-03 11:17:02 +01:00
parent 86a2ae6ba2
commit 7b0a56f40f
2 changed files with 10 additions and 7 deletions

View File

@ -29,7 +29,7 @@ namespace huggingface::tgi::backends::llamacpp {
batch.logits[batch.n_tokens] = true; batch.logits[batch.n_tokens] = true;
} }
std::unique_ptr<llama_sampler> sampling_params_t::into_llama_sampler(const llama_model *model) const { llama_sampler_ptr sampling_params_t::into_llama_sampler(const llama_model *model) const {
auto *pSampler = llama_sampler_chain_init({.no_perf = false}); auto *pSampler = llama_sampler_chain_init({.no_perf = false});
// Penalties // Penalties
@ -51,7 +51,7 @@ namespace huggingface::tgi::backends::llamacpp {
} }
llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed)); llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
return std::unique_ptr<llama_sampler>(pSampler); return llama_sampler_ptr(pSampler, llama_sampler_deleter);
} }
worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params) worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params)

View File

@ -24,7 +24,10 @@
namespace huggingface::tgi::backends::llamacpp { namespace huggingface::tgi::backends::llamacpp {
static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); }; static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr; typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_ptr;
static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback; typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback;
static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {}; static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {};
@ -51,7 +54,7 @@ namespace huggingface::tgi::backends::llamacpp {
* @param Pointer to the model data * @param Pointer to the model data
* @return * @return
*/ */
std::unique_ptr<llama_sampler> into_llama_sampler(const llama_model *pModel) const; llama_sampler_ptr into_llama_sampler(const llama_model *pModel) const;
}; };
/** /**
@ -155,7 +158,7 @@ namespace huggingface::tgi::backends::llamacpp {
class single_worker_backend_t : backend_base_t { class single_worker_backend_t : backend_base_t {
private: private:
constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_smart_ptr { constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
auto llParams = llama_context_default_params(); auto llParams = llama_context_default_params();
llParams.flash_attn = true; llParams.flash_attn = true;
llParams.n_batch = 1; llParams.n_batch = 1;
@ -165,7 +168,7 @@ namespace huggingface::tgi::backends::llamacpp {
return {llama_new_context_with_model(pModel, llParams), llama_context_deleter}; return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
}; };
llama_context_smart_ptr mContext_; llama_context_ptr mContext_;
worker_t mWorker_; worker_t mWorker_;
public: public:
@ -185,7 +188,7 @@ namespace huggingface::tgi::backends::llamacpp {
class multi_worker_backend_t : backend_base_t { class multi_worker_backend_t : backend_base_t {
private: private:
llama_context_smart_ptr mContext_; llama_context_ptr mContext_;
public: public:
std::expected<size_t, backend_error_t> generate( std::expected<size_t, backend_error_t> generate(