feat(backend): fix memory leaking on llama_sampler when the decode ends
This commit is contained in:
parent
86a2ae6ba2
commit
7b0a56f40f
|
@ -29,7 +29,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
batch.logits[batch.n_tokens] = true;
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_sampler> sampling_params_t::into_llama_sampler(const llama_model *model) const {
|
||||
llama_sampler_ptr sampling_params_t::into_llama_sampler(const llama_model *model) const {
|
||||
auto *pSampler = llama_sampler_chain_init({.no_perf = false});
|
||||
|
||||
// Penalties
|
||||
|
@ -51,7 +51,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
}
|
||||
|
||||
llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
|
||||
return std::unique_ptr<llama_sampler>(pSampler);
|
||||
return llama_sampler_ptr(pSampler, llama_sampler_deleter);
|
||||
}
|
||||
|
||||
worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params ¶ms)
|
||||
|
|
|
@ -24,7 +24,10 @@
|
|||
namespace huggingface::tgi::backends::llamacpp {
|
||||
|
||||
static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
|
||||
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
|
||||
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_ptr;
|
||||
|
||||
static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
|
||||
typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
|
||||
|
||||
typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback;
|
||||
static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {};
|
||||
|
@ -51,7 +54,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
* @param Pointer to the model data
|
||||
* @return
|
||||
*/
|
||||
std::unique_ptr<llama_sampler> into_llama_sampler(const llama_model *pModel) const;
|
||||
llama_sampler_ptr into_llama_sampler(const llama_model *pModel) const;
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -155,7 +158,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
|
||||
class single_worker_backend_t : backend_base_t {
|
||||
private:
|
||||
constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_smart_ptr {
|
||||
constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
|
||||
auto llParams = llama_context_default_params();
|
||||
llParams.flash_attn = true;
|
||||
llParams.n_batch = 1;
|
||||
|
@ -165,7 +168,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
|
||||
};
|
||||
|
||||
llama_context_smart_ptr mContext_;
|
||||
llama_context_ptr mContext_;
|
||||
worker_t mWorker_;
|
||||
|
||||
public:
|
||||
|
@ -185,7 +188,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
|
||||
class multi_worker_backend_t : backend_base_t {
|
||||
private:
|
||||
llama_context_smart_ptr mContext_;
|
||||
llama_context_ptr mContext_;
|
||||
|
||||
public:
|
||||
std::expected<size_t, backend_error_t> generate(
|
||||
|
|
Loading…
Reference in New Issue