From b10eaab9f30f7c92ec9d3f73170e69de69c185fa Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Thu, 28 Nov 2024 23:57:24 +0100 Subject: [PATCH] feat(backend): use new batch API to generate tokens --- backends/llamacpp/csrc/backend.cpp | 55 +++++++++++++++--------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp index 00692ea8..f7e4cde2 100644 --- a/backends/llamacpp/csrc/backend.cpp +++ b/backends/llamacpp/csrc/backend.cpp @@ -14,10 +14,10 @@ namespace huggingface::tgi::backends::llamacpp { llama_sampler_ptr sampling_params_t::into_llama_sampler(const llama_model *model) const { - auto *pSampler = llama_sampler_chain_init({.no_perf = false}); + auto *sampler = llama_sampler_chain_init({.no_perf = false}); // Penalties - llama_sampler_chain_add(pSampler, llama_sampler_init_penalties( + llama_sampler_chain_add(sampler, llama_sampler_init_penalties( llama_n_vocab(model), llama_token_eos(model), llama_token_nl(model), @@ -28,31 +28,27 @@ namespace huggingface::tgi::backends::llamacpp { false, false )); - - if (top_k > 0) { - llama_sampler_chain_add(pSampler, llama_sampler_init_top_k(static_cast(top_k))); - } + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(static_cast(top_k))); if (0 < top_p && top_p < 1) { - llama_sampler_chain_add(pSampler, llama_sampler_init_top_p(top_p, 1)); + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(top_p, 0)); } - llama_sampler_chain_add(pSampler, llama_sampler_init_temp(temperature)); - llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed)); - return {pSampler, llama_sampler_deleter}; + llama_sampler_chain_add(sampler, llama_sampler_init_temp(temperature)); + llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed)); + return {sampler, llama_sampler_deleter}; } - std::expected get_batch_from_prompt(std::span prompt) { auto batch = llama_batch_init(static_cast(prompt.size()), 0, 1); - std::for_each(prompt.begin(), prompt.end(), [&batch](const llama_token token) { - const auto n_token = batch.n_tokens; + batch.n_tokens = 0; - batch.token[n_token] = token; - batch.pos[n_token] = n_token; - batch.n_seq_id[n_token] = 1; - batch.seq_id[n_token][0] = 1; - batch.logits[n_token] = false; + std::for_each(prompt.begin(), prompt.end(), [&batch](const llama_token token) { + batch.token[batch.n_tokens] = token; + batch.pos[batch.n_tokens] = batch.n_tokens; + batch.n_seq_id[batch.n_tokens] = 1; + batch.seq_id[batch.n_tokens][0] = 0; + batch.logits[batch.n_tokens] = false; batch.n_tokens++; }); @@ -60,11 +56,12 @@ namespace huggingface::tgi::backends::llamacpp { return batch; } - void update_batch_for_decoding(llama_batch &batch, llama_token token, size_t position) { - batch.n_tokens = 1; - batch.logits[0] = true; + int32_t update_batch_for_decoding(llama_batch &batch, llama_token token, size_t position) { batch.token[0] = token; batch.pos[0] = static_cast(position); + batch.logits[0] = true; + batch.n_tokens = 1; + return 0; // Decoding will always happen at position 0 } worker_t::worker_t(std::shared_ptr model, const llama_context_params &¶ms) @@ -89,10 +86,14 @@ namespace huggingface::tgi::backends::llamacpp { // Set up the prompt if (auto maybe_batch = get_batch_from_prompt(generation_context.input_tokens); maybe_batch.has_value()) { - // Decode auto batch = *maybe_batch; + + // Keep track of where we are auto n_decoded_tokens = 0; - const auto prompt_size = generation_context.input_tokens.size(); + auto position = batch.n_tokens; + auto sampling_index = batch.n_tokens - 1; + + // Decode for (bool generating = true; generating; ++n_decoded_tokens) { #ifdef TGI_LLAMACPP_BACKEND_DEBUG @@ -104,12 +105,11 @@ namespace huggingface::tgi::backends::llamacpp { #else const auto status = llama_decode(context_.get(), batch); #endif - batch.n_tokens = 0; if (LLAMA_SUCCESS(status)) [[likely]] { // Sample the new token - auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1); + auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), sampling_index); const auto is_eog = llama_token_is_eog(model_.get(), new_token_id); - const auto new_token_logits = llama_get_logits_ith(context_.get(), -1); // TODO: return logit + const auto *new_token_logits = llama_get_logits_ith(context_.get(), sampling_index) + new_token_id; // Handle termination cases const bool has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1; @@ -123,7 +123,8 @@ namespace huggingface::tgi::backends::llamacpp { generating = !(should_stop | is_final); // Update the batch for the next generation - update_batch_for_decoding(batch, new_token_id, prompt_size + n_decoded_tokens); + sampling_index = update_batch_for_decoding(batch, new_token_id, position); + position += 1; } }