feat(backend): fix memory leaking on llama_sampler when the decode ends
This commit is contained in:
parent
86a2ae6ba2
commit
7b0a56f40f
|
@ -29,7 +29,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
batch.logits[batch.n_tokens] = true;
|
batch.logits[batch.n_tokens] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<llama_sampler> sampling_params_t::into_llama_sampler(const llama_model *model) const {
|
llama_sampler_ptr sampling_params_t::into_llama_sampler(const llama_model *model) const {
|
||||||
auto *pSampler = llama_sampler_chain_init({.no_perf = false});
|
auto *pSampler = llama_sampler_chain_init({.no_perf = false});
|
||||||
|
|
||||||
// Penalties
|
// Penalties
|
||||||
|
@ -51,7 +51,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
|
llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
|
||||||
return std::unique_ptr<llama_sampler>(pSampler);
|
return llama_sampler_ptr(pSampler, llama_sampler_deleter);
|
||||||
}
|
}
|
||||||
|
|
||||||
worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params ¶ms)
|
worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params ¶ms)
|
||||||
|
|
|
@ -24,7 +24,10 @@
|
||||||
namespace huggingface::tgi::backends::llamacpp {
|
namespace huggingface::tgi::backends::llamacpp {
|
||||||
|
|
||||||
static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
|
static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
|
||||||
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
|
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_ptr;
|
||||||
|
|
||||||
|
static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
|
||||||
|
typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
|
||||||
|
|
||||||
typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback;
|
typedef std::function<void(llama_token, float_t, bool)> llama_decode_callback;
|
||||||
static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {};
|
static constexpr auto llama_void_callback = [](llama_token, float_t, bool) {};
|
||||||
|
@ -51,7 +54,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
* @param Pointer to the model data
|
* @param Pointer to the model data
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
std::unique_ptr<llama_sampler> into_llama_sampler(const llama_model *pModel) const;
|
llama_sampler_ptr into_llama_sampler(const llama_model *pModel) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -155,7 +158,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
|
|
||||||
class single_worker_backend_t : backend_base_t {
|
class single_worker_backend_t : backend_base_t {
|
||||||
private:
|
private:
|
||||||
constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_smart_ptr {
|
constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
|
||||||
auto llParams = llama_context_default_params();
|
auto llParams = llama_context_default_params();
|
||||||
llParams.flash_attn = true;
|
llParams.flash_attn = true;
|
||||||
llParams.n_batch = 1;
|
llParams.n_batch = 1;
|
||||||
|
@ -165,7 +168,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
|
return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
|
||||||
};
|
};
|
||||||
|
|
||||||
llama_context_smart_ptr mContext_;
|
llama_context_ptr mContext_;
|
||||||
worker_t mWorker_;
|
worker_t mWorker_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -185,7 +188,7 @@ namespace huggingface::tgi::backends::llamacpp {
|
||||||
|
|
||||||
class multi_worker_backend_t : backend_base_t {
|
class multi_worker_backend_t : backend_base_t {
|
||||||
private:
|
private:
|
||||||
llama_context_smart_ptr mContext_;
|
llama_context_ptr mContext_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
std::expected<size_t, backend_error_t> generate(
|
std::expected<size_t, backend_error_t> generate(
|
||||||
|
|
Loading…
Reference in New Issue