(post) impl postprocessing

Morgan Funtowicz 2024-08-26 14:28:44 +00:00 committed by Morgan Funtowicz
parent fa63db0d07
commit 984ae9798f
1 changed file with 66 additions and 13 deletions


@@ -3,22 +3,23 @@ use std::ops::Deref;
 use std::path::Path;
 
 use async_trait::async_trait;
-use cxx::{UniquePtr};
-use hashbrown::{HashMap};
+use cxx::UniquePtr;
+use hashbrown::HashMap;
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::task::{spawn_blocking, JoinHandle};
+use tokio::task::{JoinHandle, spawn_blocking};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error};
 
+use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
+use text_generation_router::validation::{Chunk, ValidGenerateRequest};
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
@@ -71,6 +72,8 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
 
 /// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
 struct DecodedTokenContext {
     token: DecodedToken,
+    start: Option<Instant>,
+    queued: Instant,
     channel: UnboundedSender<InferResult<InferStreamResponse>>,
 }
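
The two new fields carry per-request timing alongside each decoded token: queued records when the request entered the backend, while start stays None until the executor actually schedules it, and both are replayed into the final streamed response. A minimal sketch of this parcel pattern, using an invented Parcel type rather than the commit's DecodedTokenContext:

    use tokio::sync::mpsc::unbounded_channel;
    use tokio::time::Instant;

    // Hypothetical stand-in for DecodedTokenContext: a decoded token id plus
    // the timing data the final response will need.
    struct Parcel {
        token_id: u32,
        start: Option<Instant>, // None until the executor schedules the request
        queued: Instant,        // when the request was accepted
    }

    fn main() {
        let (tx, mut rx) = unbounded_channel::<Parcel>();
        let queued = Instant::now();
        tx.send(Parcel { token_id: 42, start: Some(Instant::now()), queued })
            .expect("receiver still alive");
        // A post-processing thread would drain the channel with blocking_recv
        let parcel = rx.blocking_recv().expect("one parcel queued");
        assert_eq!(parcel.token_id, 42);
    }
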
@@ -131,10 +134,12 @@ fn executor_status_looper(
             // Iterate through all the decoded token
             for step in responses.deref() {
                 if let Some(ctx) = in_flights.get(&step.request_id) {
-                    // Remove from tracked requests
-                    let parcel = DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
+                    let parcel =
+                        DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
                             token: dt,
+                            start: ctx.start,
+                            queued: ctx.queued,
                             channel: ctx.streamer.clone(),
                         });
@@ -148,7 +153,7 @@ fn executor_status_looper(
                 } else {
                     warn!("Untracked request {}", step.request_id,);
                 }
-            };
+            }
         }
         Err(ref err) => {
             error!("Failed to get responses from the executor: {}.", err.what());
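
The looper converts each raw FFI GenerationStep into a typed DecodedToken through TryFrom, so executor-side failures surface as a Result rather than a panic, and the timing fields are copied out of the tracked in-flight entry before the parcel is forwarded. A self-contained sketch of that TryFrom-then-map shape, with invented RawStep and Decoded names standing in for the commit's FFI types:

    // Invented stand-ins for the commit's ffi::GenerationStep / DecodedToken.
    struct RawStep {
        request_id: u64,
        token_id: u32,
        has_error: bool,
    }

    struct Decoded {
        id: u32,
    }

    impl TryFrom<&RawStep> for Decoded {
        type Error = String;
        fn try_from(step: &RawStep) -> Result<Self, Self::Error> {
            if step.has_error {
                Err(format!("request {} failed in the executor", step.request_id))
            } else {
                Ok(Decoded { id: step.token_id })
            }
        }
    }

    fn main() {
        let step = RawStep { request_id: 7, token_id: 1332, has_error: false };
        // Mirrors the hunk: try_from(..).map(|dt| ..) builds the parcel only
        // when conversion succeeds; an Err is forwarded to the client as-is.
        let parcel = Decoded::try_from(&step).map(|dt| (step.request_id, dt.id));
        assert_eq!(parcel.unwrap(), (7, 1332));
    }
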
@@ -176,12 +181,60 @@ fn post_processor_looper(
         if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
             let state = states.entry(request_id).or_insert(vec![]);
             match decoded {
+                Ok(ctx) => {
+                    state.push(ctx.token.id);
+                    let out = match tokenizer.decode(&[ctx.token.id], false) {
+                        Ok(text) => {
+                            let is_special =
+                                tokenizer.get_added_vocabulary().is_special_token(&text);
+                            let token = Token {
+                                id: ctx.token.id,
+                                text,
+                                logprob: ctx.token.log_prob,
+                                special: is_special,
+                            };
+                            let out = if !ctx.token.is_final {
+                                InferStreamResponse::Intermediate {
+                                    token,
+                                    top_tokens: vec![],
+                                }
+                            } else {
+                                let text = tokenizer.decode(&state, true);
+                                let generated_text = GeneratedText {
+                                    text: text.unwrap(),
+                                    generated_tokens: state.len() as u32,
+                                    finish_reason: FinishReason::EndOfSequenceToken,
+                                    seed: None,
+                                };
+                                InferStreamResponse::End {
+                                    token,
+                                    top_tokens: vec![],
+                                    generated_text,
+                                    start: ctx.start.unwrap(),
+                                    queued: ctx.queued,
+                                }
+                            };
+                            Ok(out)
+                        }
+                        Err(err) => Err(GenerationError(err.to_string())),
+                    };
+                    if let Err(_) = ctx.channel.send(out) {
+                        warn!("Failed to send decoded token back to the user")
+                    }
+                }
+                Err(err) => {}
             }
         }
     }
 }
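
The post-processor decodes in two phases: each incoming id is decoded alone with skip_special_tokens set to false so the stream surfaces every token as it arrives, and once is_final is set the whole accumulated state is decoded again with skip_special_tokens set to true to produce the final text (here with finish_reason hardcoded to EndOfSequenceToken and seed left as None). Both decode and get_added_vocabulary().is_special_token are real APIs of the tokenizers crate. A runnable sketch of the two phases, assuming a tokenizer.json on disk:

    use tokenizers::Tokenizer;

    fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // The path is an assumption for this sketch; any HF tokenizer.json works.
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;
        let mut state: Vec<u32> = Vec::new();

        for id in [318u32, 257, 1332] {
            state.push(id);
            // Streaming phase: decode one id, keeping special tokens visible
            let piece = tokenizer.decode(&[id], false)?;
            let special = tokenizer.get_added_vocabulary().is_special_token(&piece);
            println!("piece={piece:?} special={special}");
        }

        // Final phase: decode the accumulated ids, skipping special tokens
        let full = tokenizer.decode(&state, true)?;
        println!("full={full:?}");
        Ok(())
    }

One caveat with this per-id approach: byte-level BPE tokenizers can split a multi-byte character across ids, so decoding ids one at a time may emit replacement characters mid-stream; an incremental decode that re-decodes a trailing window and diffs the prefix avoids that, at the cost of extra bookkeeping.
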
-unsafe impl Send for crate::ffi::TensorRtLlmBackendImpl {}
+unsafe impl Send for TensorRtLlmBackendImpl {}
 
 pub struct TensorRtLlmBackendV2 {
     tokenizer: Tokenizer,
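
On the Send marker: TensorRtLlmBackendImpl is an opaque C++ type driven through cxx, and cxx makes no thread-safety promise for opaque types, so moving the backend into spawned tasks requires the author to assert Send explicitly; this commit merely shortens the path now that the type is imported. The same pattern in miniature, with a raw-pointer wrapper standing in for the FFI handle:

    use std::ffi::c_void;
    use std::ptr;
    use std::thread;

    // Stand-in for an FFI handle: the raw pointer makes the type !Send.
    struct FfiHandle(*mut c_void);

    // SAFETY: asserted by the author, not checked by the compiler. The
    // assumption here is the handle is only used by one thread at a time.
    unsafe impl Send for FfiHandle {}

    fn main() {
        let handle = FfiHandle(ptr::null_mut());
        thread::spawn(move || {
            // This move compiles only because of the unsafe impl above.
            let FfiHandle(raw) = handle;
            assert!(raw.is_null());
        })
        .join()
        .unwrap();
    }
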