From 984ae9798f8cb19863282c0145501d329146d5d1 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 26 Aug 2024 14:28:44 +0000
Subject: [PATCH] (post) impl postprocessing

---
 backends/trtllm/src/looper.rs | 79 +++++++++++++++++++++++++++++------
 1 file changed, 66 insertions(+), 13 deletions(-)

diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index 4247f338..ba10d9ee 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -3,22 +3,23 @@ use std::ops::Deref;
 use std::path::Path;
 
 use async_trait::async_trait;
-use cxx::{UniquePtr};
-use hashbrown::{HashMap};
+use cxx::UniquePtr;
+use hashbrown::HashMap;
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::task::{spawn_blocking, JoinHandle};
+use tokio::task::{JoinHandle, spawn_blocking};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error};
 
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
@@ -71,6 +72,8 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
 /// Wraps the decoded token with the channel used to stream the decoded tokens back to the client
 struct DecodedTokenContext {
     token: DecodedToken,
+    start: Option<Instant>,
+    queued: Instant,
     channel: UnboundedSender<Result<InferStreamResponse, InferError>>,
 }
 
@@ -131,12 +134,14 @@ fn executor_status_looper(
                     // Iterate through all the decoded tokens
                     for step in responses.deref() {
                         if let Some(ctx) = in_flights.get(&step.request_id) {
-                            // Remove from tracked requests
-                            let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
-                                token: dt,
-                                channel: ctx.streamer.clone(),
-                            });
+                            let parcel =
+                                DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
+                                    token: dt,
+                                    start: ctx.start,
+                                    queued: ctx.queued,
+                                    channel: ctx.streamer.clone(),
+                                });
 
                             // Submit the work to the post_processor
                             let posted = post_processor_sender.send((step.request_id, parcel));
 
@@ -148,7 +153,7 @@ fn executor_status_looper(
                         } else {
                             warn!("Untracked request {}", step.request_id,);
                         }
-                    };
+                    }
                 }
                 Err(ref err) => {
                     error!("Failed to get responses from the executor: {}.", err.what());
                     break;
                 }
             }
@@ -176,12 +181,60 @@ fn post_processor_looper(
         if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
             let state = states.entry(request_id).or_insert(vec![]);
+
+            match decoded {
+                Ok(ctx) => {
+                    state.push(ctx.token.id);
+                    let out = match tokenizer.decode(&[ctx.token.id], false) {
+                        Ok(text) => {
+                            let is_special =
+                                tokenizer.get_added_vocabulary().is_special_token(&text);
+                            let token = Token {
+                                id: ctx.token.id,
+                                text,
+                                logprob: ctx.token.log_prob,
+                                special: is_special,
+                            };
+
+                            if !ctx.token.is_final {
+                                Ok(InferStreamResponse::Intermediate {
+                                    token,
+                                    top_tokens: vec![],
+                                })
+                            } else {
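+                                // The sequence is complete: decode all the accumulated
+                                // token ids to rebuild the final generated text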
+                                tokenizer
+                                    .decode(&state, true)
+                                    .map_err(|err| GenerationError(err.to_string()))
+                                    .map(|text| InferStreamResponse::End {
+                                        token,
+                                        top_tokens: vec![],
+                                        generated_text: GeneratedText {
+                                            text,
+                                            generated_tokens: state.len() as u32,
+                                            finish_reason: FinishReason::EndOfSequenceToken,
+                                            seed: None,
+                                        },
+                                        start: ctx.start.expect("start time to be recorded"),
+                                        queued: ctx.queued,
+                                    })
+                            }
+                        }
+                        Err(err) => Err(GenerationError(err.to_string())),
+                    };
+
+                    if ctx.channel.send(out).is_err() {
+                        warn!("Failed to send decoded token back to the user");
+                    }
+                }
+                Err(err) => error!("Post-processing error for request {}: {}", request_id, err),
+            }
         }
     }
 }
-
-unsafe impl Send for crate::ffi::TensorRtLlmBackendImpl {}
+unsafe impl Send for TensorRtLlmBackendImpl {}
 
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,