feat(backend): use new batch API to generate tokens

Morgan Funtowicz 2024-11-28 23:57:24 +01:00
parent dc6435e3a5
commit b10eaab9f3
1 changed file with 28 additions and 27 deletions


@@ -14,10 +14,10 @@
 namespace huggingface::tgi::backends::llamacpp {

     llama_sampler_ptr sampling_params_t::into_llama_sampler(const llama_model *model) const {
-        auto *pSampler = llama_sampler_chain_init({.no_perf = false});
+        auto *sampler = llama_sampler_chain_init({.no_perf = false});

         // Penalties
-        llama_sampler_chain_add(pSampler, llama_sampler_init_penalties(
+        llama_sampler_chain_add(sampler, llama_sampler_init_penalties(
                 llama_n_vocab(model),
                 llama_token_eos(model),
                 llama_token_nl(model),
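Note: the chain built in this function (it continues in the next hunk) follows the usual llama.cpp sampler-chain pattern: initialize a chain, add samplers in the order they should run, then draw a token from the logits of a decoded batch index. A minimal sketch of that pattern, using the library's default chain params and placeholder values rather than the backend's sampling_params_t fields, could look like this:

#include <llama.h>

// Sketch only: placeholder sampling values, not the backend's configuration.
static llama_token sample_once(llama_context *ctx, int32_t logits_index) {
    auto *chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // Samplers run in the order they are added to the chain.
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.95f, /* min_keep = */ 1));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(/* seed = */ 42));

    // Sample from the logits computed at `logits_index` of the last decoded batch.
    const llama_token token = llama_sampler_sample(chain, ctx, logits_index);

    llama_sampler_free(chain);
    return token;
}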
@@ -28,31 +28,27 @@ namespace huggingface::tgi::backends::llamacpp {
                 false,
                 false
         ));
-        if (top_k > 0) {
-            llama_sampler_chain_add(pSampler, llama_sampler_init_top_k(static_cast<int32_t>(top_k)));
-        }
+        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(static_cast<int32_t>(top_k)));

         if (0 < top_p && top_p < 1) {
-            llama_sampler_chain_add(pSampler, llama_sampler_init_top_p(top_p, 1));
+            llama_sampler_chain_add(sampler, llama_sampler_init_top_p(top_p, 0));
         }

-        llama_sampler_chain_add(pSampler, llama_sampler_init_temp(temperature));
-        llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
-        return {pSampler, llama_sampler_deleter};
+        llama_sampler_chain_add(sampler, llama_sampler_init_temp(temperature));
+        llama_sampler_chain_add(sampler, llama_sampler_init_dist(seed));
+        return {sampler, llama_sampler_deleter};
     }

     std::expected<llama_batch, backend_error_t> get_batch_from_prompt(std::span<llama_token> prompt) {
         auto batch = llama_batch_init(static_cast<int32_t>(prompt.size()), 0, 1);
-        std::for_each(prompt.begin(), prompt.end(), [&batch](const llama_token token) {
-            const auto n_token = batch.n_tokens;
-            batch.token[n_token] = token;
-            batch.pos[n_token] = n_token;
-            batch.n_seq_id[n_token] = 1;
-            batch.seq_id[n_token][0] = 1;
-            batch.logits[n_token] = false;
+        batch.n_tokens = 0;
+
+        std::for_each(prompt.begin(), prompt.end(), [&batch](const llama_token token) {
+            batch.token[batch.n_tokens] = token;
+            batch.pos[batch.n_tokens] = batch.n_tokens;
+            batch.n_seq_id[batch.n_tokens] = 1;
+            batch.seq_id[batch.n_tokens][0] = 0;
+            batch.logits[batch.n_tokens] = false;
             batch.n_tokens++;
         });
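For reference, the rewritten get_batch_from_prompt fills the llama_batch field by field, one prompt token per slot. A hedged standalone sketch of the same pattern follows; the make_prompt_batch name is mine, and flagging logits only on the last prompt slot is an assumption consistent with the sampling_index = batch.n_tokens - 1 used later (the corresponding line of the real function is not shown in this hunk):

#include <span>
#include <llama.h>

// Sketch of the batch-building pattern above (hypothetical helper name).
llama_batch make_prompt_batch(std::span<const llama_token> prompt) {
    // Room for every prompt token, no embeddings, a single sequence.
    llama_batch batch = llama_batch_init(static_cast<int32_t>(prompt.size()), 0, 1);
    batch.n_tokens = 0;

    for (const llama_token token : prompt) {
        batch.token[batch.n_tokens]     = token;
        batch.pos[batch.n_tokens]       = batch.n_tokens;  // position == index in the prompt
        batch.n_seq_id[batch.n_tokens]  = 1;               // each token belongs to one sequence
        batch.seq_id[batch.n_tokens][0] = 0;               // ...sequence id 0
        batch.logits[batch.n_tokens]    = false;           // no logits for intermediate tokens
        batch.n_tokens++;
    }

    // Assumption: logits are only sampled at the last prompt position,
    // so only that slot asks llama_decode to produce them.
    if (batch.n_tokens > 0) {
        batch.logits[batch.n_tokens - 1] = true;
    }
    return batch;
}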
@@ -60,11 +56,12 @@ namespace huggingface::tgi::backends::llamacpp {
         return batch;
     }

-    void update_batch_for_decoding(llama_batch &batch, llama_token token, size_t position) {
-        batch.n_tokens = 1;
-        batch.logits[0] = true;
+    int32_t update_batch_for_decoding(llama_batch &batch, llama_token token, size_t position) {
         batch.token[0] = token;
         batch.pos[0] = static_cast<int32_t>(position);
+        batch.logits[0] = true;
+        batch.n_tokens = 1;
+        return 0; // Decoding will always happen at position 0
     }

     worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &&params)
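The helper now returns the batch index whose logits the next llama_decode call will fill (always 0, since the batch is shrunk to a single token). A hedged usage sketch, with update_batch_for_decoding forward-declared as in the diff and the context and sampler handles passed in explicitly:

#include <cstddef>
#include <llama.h>

// Signature from the diff; the definition lives in the backend.
int32_t update_batch_for_decoding(llama_batch &batch, llama_token token, size_t position);

// Sketch: one decode step that feeds back the previously sampled token.
llama_token decode_next(llama_context *ctx, llama_sampler *chain,
                        llama_batch &batch, llama_token last_token, size_t &position) {
    // Slot 0 is reused for the new token; the returned index is where to sample.
    const int32_t sampling_index = update_batch_for_decoding(batch, last_token, position);
    position += 1;

    // llama_decode returns 0 on success.
    if (llama_decode(ctx, batch) != 0) {
        return -1; // sketch-level failure marker, not a valid token id
    }
    return llama_sampler_sample(chain, ctx, sampling_index);
}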
@@ -89,10 +86,14 @@ namespace huggingface::tgi::backends::llamacpp {
             // Set up the prompt
             if (auto maybe_batch = get_batch_from_prompt(generation_context.input_tokens); maybe_batch.has_value()) {
-                // Decode
                 auto batch = *maybe_batch;

+                // Keep track of where we are
                 auto n_decoded_tokens = 0;
-                const auto prompt_size = generation_context.input_tokens.size();
+                auto position = batch.n_tokens;
+                auto sampling_index = batch.n_tokens - 1;

+                // Decode
                 for (bool generating = true; generating; ++n_decoded_tokens) {
 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
@@ -104,12 +105,11 @@ namespace huggingface::tgi::backends::llamacpp {
 #else
                     const auto status = llama_decode(context_.get(), batch);
 #endif
-                    batch.n_tokens = 0;
                     if (LLAMA_SUCCESS(status)) [[likely]] {
                         // Sample the new token
-                        auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1);
+                        auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), sampling_index);
                         const auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
-                        const auto new_token_logits = llama_get_logits_ith(context_.get(), -1); // TODO: return logit
+                        const auto *new_token_logits = llama_get_logits_ith(context_.get(), sampling_index) + new_token_id;

                         // Handle termination cases
                         const bool has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
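One detail worth noting: llama_get_logits_ith(ctx, sampling_index) returns a pointer to the full vocabulary-sized row of logits computed for that batch index, so offsetting it by the sampled token id reads that token's raw logit (for reporting; sampling itself already went through the chain). A tiny sketch:

#include <llama.h>

// Sketch: read the raw logit of one token from the row computed at `sampling_index`.
float logit_of(llama_context *ctx, int32_t sampling_index, llama_token token_id) {
    const float *row = llama_get_logits_ith(ctx, sampling_index);  // vocab-sized array
    return row[token_id];
}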
@@ -123,7 +123,8 @@ namespace huggingface::tgi::backends::llamacpp {
                         generating = !(should_stop | is_final);

                         // Update the batch for the next generation
-                        update_batch_for_decoding(batch, new_token_id, prompt_size + n_decoded_tokens);
+                        sampling_index = update_batch_for_decoding(batch, new_token_id, position);
+                        position += 1;
                     }
                 }
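Putting the pieces together, the loop introduced by this commit decodes the whole prompt once (logits requested only for the last prompt slot), samples at sampling_index, then repeatedly resubmits a one-token batch at the next position. A self-contained, hedged sketch of that flow with plain llama.cpp calls (no TGI types; ctx, model, and chain are assumed to be initialized by the caller, and error handling is minimal):

#include <cstddef>
#include <vector>
#include <llama.h>

// Sketch of the prompt-then-single-token decode loop (not the backend's code).
std::vector<llama_token> generate(llama_context *ctx, const llama_model *model,
                                  llama_sampler *chain,
                                  const std::vector<llama_token> &prompt,
                                  size_t max_new_tokens) {
    std::vector<llama_token> out;
    if (prompt.empty()) return out;

    // Prompt batch: one slot per token, logits requested only for the last one.
    llama_batch batch = llama_batch_init(static_cast<int32_t>(prompt.size()), 0, 1);
    batch.n_tokens = 0;
    for (const llama_token token : prompt) {
        batch.token[batch.n_tokens]     = token;
        batch.pos[batch.n_tokens]       = batch.n_tokens;
        batch.n_seq_id[batch.n_tokens]  = 1;
        batch.seq_id[batch.n_tokens][0] = 0;
        batch.logits[batch.n_tokens]    = false;
        batch.n_tokens++;
    }
    batch.logits[batch.n_tokens - 1] = true;

    llama_pos position       = batch.n_tokens;      // next position to write
    int32_t   sampling_index = batch.n_tokens - 1;  // slot whose logits we sample

    while (out.size() < max_new_tokens) {
        if (llama_decode(ctx, batch) != 0) break;   // 0 means success

        const llama_token token = llama_sampler_sample(chain, ctx, sampling_index);
        out.push_back(token);
        if (llama_token_is_eog(model, token)) break;

        // Reuse slot 0 for the single new token, as update_batch_for_decoding does.
        batch.token[0]  = token;
        batch.pos[0]    = position++;
        batch.logits[0] = true;
        batch.n_tokens  = 1;
        sampling_index  = 0;
    }

    llama_batch_free(batch);
    return out;
}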