diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index b60c3ddc..17709b72 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -38,6 +38,31 @@ namespace huggingface::tgi::backends::llamacpp {
         return {pSampler, llama_sampler_deleter};
     }
 
+
+    std::expected<llama_batch, backend_error_t> get_batch_from_prompt(std::span<const llama_token> prompt) {
+        auto batch = llama_batch_init(static_cast<int32_t>(prompt.size()), 0, 1);
+        std::for_each(prompt.begin(), prompt.end(), [&batch](const llama_token token) {
+            const auto n_token = batch.n_tokens;
+
+            batch.token[n_token] = token;
+            batch.pos[n_token] = n_token;
+            batch.n_seq_id[n_token] = 1;
+            batch.seq_id[n_token][0] = 1;
+            batch.logits[n_token] = false;
+            batch.n_tokens++;
+        });
+
+        batch.logits[batch.n_tokens - 1] = true;
+        return batch;
+    }
+
+    void update_batch_for_decoding(llama_batch &batch, llama_token token, size_t position) {
+        batch.n_tokens = 1;
+        batch.logits[0] = true;
+        batch.token[0] = token;
+        batch.pos[0] = static_cast<llama_pos>(position);
+    }
+
     worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &&params)
             : model_(model), context_(llama_new_context_with_model(model_.get(), params)) {
 
@@ -59,44 +84,50 @@ namespace huggingface::tgi::backends::llamacpp {
         auto sampler = generation_context.sampling_params.into_llama_sampler(model_.get());
 
         // Set up the prompt
-        auto copy = std::vector(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
-        auto batch = llama_batch_get_one(copy.data(), copy.size());
-
-        // Decode
-        auto n_decoded_tokens = 0;
-        for (bool generating = true; generating; ++n_decoded_tokens) {
+        if (auto maybe_batch = get_batch_from_prompt(generation_context.input_tokens); maybe_batch.has_value()) {
+            // Decode
+            auto batch = *maybe_batch;
+            auto n_decoded_tokens = 0;
+            const auto prompt_size = generation_context.input_tokens.size();
+            for (bool generating = true; generating; ++n_decoded_tokens) {
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
-            const auto start = std::chrono::steady_clock::now();
-            const auto status = llama_decode(context_.get(), batch);
-            const auto end = std::chrono::steady_clock::now();
-            const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
-            SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
+                const auto start = std::chrono::steady_clock::now();
+                const auto status = llama_decode(context_.get(), batch);
+                const auto end = std::chrono::steady_clock::now();
+                const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+                SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
 #else
-            const auto status = llama_decode(context_.get(), batch);
+                const auto status = llama_decode(context_.get(), batch);
 #endif
 
-            batch.n_tokens = 0;
-            if (LLAMA_SUCCESS(status)) [[likely]] {
-                // Sample the new token
-                auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1);
-                auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
-                auto new_token_logits = 0.0f; // TODO: return logit
+                batch.n_tokens = 0;
+                if (LLAMA_SUCCESS(status)) [[likely]] {
+                    // Sample the new token
+                    auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1);
+                    const auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
+                    const auto new_token_logits = llama_get_logits_ith(context_.get(), -1); // TODO: return logit
 
-                // Handle termination cases
-                const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
-                const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
+                    // Handle termination cases
+                    const bool has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
+                    const bool has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
+                    const bool is_final = has_reach_max_tokens | has_reach_eog;
 
-                generating = !(has_reach_max_tokens | has_reach_eog);
+                    // Bubble up the generated token if a callback is provided
+                    const auto should_stop = callback_(new_token_id, *new_token_logits, is_final, n_decoded_tokens + 1);
 
-                // Bubble up the generated token if a callback is provided
-                const auto should_stop =
-                        std::invoke(callback_, new_token_id, new_token_logits, !generating, n_decoded_tokens + 1);
-                generating ^= should_stop;
+                    // Compute the continuation flag
+                    generating = !(should_stop | is_final);
 
-                batch = llama_batch_get_one(&new_token_id, 1);
+                    // Update the batch for the next generation
+                    update_batch_for_decoding(batch, new_token_id, prompt_size + n_decoded_tokens);
+                }
             }
-        }
 
-        return n_decoded_tokens;
+            llama_batch_free(batch);
+
+            return n_decoded_tokens;
+        } else {
+            return maybe_batch.error();
+        }
     }
 }
\ No newline at end of file
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index 0e1a13ac..321b667a 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -75,7 +75,7 @@ namespace huggingface::tgi::backends::llamacpp {
     struct generation_context_t {
         generation_params_t generation_params;
         sampling_params_t sampling_params;
-        std::span<llama_token> input_tokens;
+        std::span<const llama_token> input_tokens;
     };
 
     /**