make sure the context is not dropped in the middle of the async decoding.
Parent: 9220340ff7
Commit: e983ee5bb8
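In short: the Rust GenerationContext is no longer handed to C++ as an owned rust::Box (which could be dropped while the executor was still decoding). Instead it is Box::leak-ed into a raw pointer, the FFI callback receives it as *mut GenerationContext, and the box is only reclaimed with Box::from_raw once streaming has finished. Below is a minimal, standalone sketch of that ownership pattern; the names (StreamContext, fake_stream, on_token) are made up for illustration and nothing in it is TensorRT-LLM-specific.

    // Illustrative sketch only: StreamContext, fake_stream and on_token are made-up names.
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    struct StreamContext {
        done: Arc<AtomicBool>,
        received: Vec<u32>,
    }

    // Stand-in for the C++ executor: it only ever sees a raw pointer and a callback.
    unsafe fn fake_stream(ctx: *mut StreamContext, cb: unsafe fn(*mut StreamContext, u32, bool)) {
        for (i, token) in [42_u32, 7, 0].into_iter().enumerate() {
            unsafe { cb(ctx, token, i == 2) };
        }
    }

    unsafe fn on_token(ctx: *mut StreamContext, token: u32, is_final: bool) {
        unsafe {
            (*ctx).received.push(token);
            if is_final {
                (*ctx).done.store(true, Ordering::Relaxed);
            }
        }
    }

    fn main() {
        let done = Arc::new(AtomicBool::new(false));

        // Leak the box so the pointer stays valid even if the surrounding future is dropped.
        let raw: *mut StreamContext = Box::leak(Box::new(StreamContext {
            done: done.clone(),
            received: Vec::new(),
        }));

        unsafe { fake_stream(raw, on_token) };

        // Reclaim ownership exactly once, after no more callbacks can fire.
        let ctx = unsafe { Box::from_raw(raw) };
        assert!(done.load(Ordering::Relaxed));
        println!("received {:?}", ctx.received);
    }

In the actual backend the same roles are played by GenerationContext, the stream_tokens FFI call and the closure passed to it; the Box::from_raw happens after the polling loop in generate(), as shown in the diff below.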
@@ -6,19 +6,19 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
-async-trait = "0.1.74"
-async-stream = "0.3.5"
+async-trait = "0.1"
+async-stream = "0.3"
 cxx = "1.0"
 text-generation-router = { path = "../../router" }
 tokenizers = { version = "0.19", features = ["hf-hub"] }
-tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.14"
-clap = { version = "4.5.4", features = ["derive"] }
-thiserror = "1.0.61"
+tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.15"
+clap = { version = "4.5", features = ["derive"] }
+thiserror = "1.0.62"
 tracing = "0.1"
 tracing-opentelemetry = "0.24"
 tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
-log = { version = "0.4.21", features = [] }
+log = { version = "0.4", features = [] }
 
 [build-dependencies]
 cmake = "0.1"
@@ -50,8 +50,7 @@ namespace huggingface::tgi::backends {
                 uint32_t topK,
                 float_t topP,
                 float_t temperature,
-                uint64_t seed,
-                std::optional<int32_t> beamWidth
+                uint64_t seed
         );
 
         /**
@@ -56,8 +56,8 @@ namespace huggingface::tgi::backends {
         */
        size_t StreamTokens(
                const RequestId requestId,
-               rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
-               rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback);
+               huggingface::tgi::backends::GenerationContext *ctx,
+               rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback);
    };
 
    /***
@@ -57,10 +57,9 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
         uint32_t topK,
         float_t topP,
         float_t temperature,
-        uint64_t seed,
-        std::optional<int32_t> beamWidth = std::nullopt) {
+        uint64_t seed) {
     return tle::SamplingConfig(
-            beamWidth.value_or(1),
+            1, // TGI only use a single beam
             topK,
             topP,
             std::nullopt,
@@ -116,11 +115,11 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     );
 #endif
 
-    const auto maxNumTokens = config["max_num_tokens"_json_pointer].get<size_t>();
+    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
     const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
 
     const auto sampling = GetSamplingConfig(topK, topP, temperature, seed);
-    const auto output = tle::OutputConfig(false, false, false, true, false);
+    const auto output = tle::OutputConfig(true, false, false, true, false);
     return executor.enqueueRequest(
             tle::Request{tokens, maxNewTokens, true, sampling, output});
 }
@@ -8,7 +8,7 @@ use std::time::Duration;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
-use log::{info, warn};
+use log::{debug, info, warn};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::RwLock;
@@ -19,7 +19,8 @@ use tracing::{instrument, Level, span};
 
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
+use text_generation_router::validation::ValidationError::UnsupportedModality;
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
@@ -55,10 +56,12 @@ pub(crate) struct Generation {
     done: Arc<AtomicBool>,
 }
 
-pub struct GenerationContext(
-    UnboundedSender<InferResult<InferStreamResponse>>,
-    Arc<AtomicBool>,
-);
+#[derive(Clone)]
+pub struct GenerationContext {
+    sender: UnboundedSender<InferResult<InferStreamResponse>>,
+    tokenizer: Arc<Tokenizer>,
+    done: Arc<AtomicBool>,
+}
 
 impl Stream for Generation {
     type Item = usize;
@@ -110,24 +113,160 @@ unsafe impl Sync for TensorRtLlmBackendImpl {}
 
 /// Implements the logic to execute generation with TensorRT-LLM executor API in background
 pub struct TensorRtLlmBackend {
-    // Allowing sending user requests to the TensorRT-LLM executor thread
-    // batcher: UnboundedSender<InferenceContext>,
+    tokenizer: Arc<Tokenizer>,
     backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
 }
 
 impl TensorRtLlmBackend {
     pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
-        _tokenizer: Tokenizer,
+        tokenizer: Tokenizer,
         engine_folder: P,
         _executor_worker_path: Option<PP>,
     ) -> Result<Self, TensorRtLlmBackendError> {
         Ok(TensorRtLlmBackend {
+            tokenizer: Arc::new(tokenizer),
             backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
                 engine_folder.as_ref().to_str().unwrap(),
                 "",
             ))),
         })
     }
 
+    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+        if request.top_n_tokens > 1 {
+            return Err(InferError::ValidationError(
+                ValidationError::TopNTokensDisabled,
+            ));
+        }
+
+        match request.inputs.len() {
+            0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
+            2.. => Err(InferError::GenerationError(
+                "TensorRT-LLM backend don't support multi-chunk".into(),
+            )),
+            1 => match request.inputs.first().expect("Single item-chunk") {
+                Chunk::Text(text) => Ok(text),
+                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+            },
+        }
+    }
+
+    fn generate(
+        &self,
+        sender: UnboundedSender<InferResult<InferStreamResponse>>,
+        tokens: Vec<u32>,
+        top_k: u32,
+        top_p: f32,
+        temperature: f32,
+        seed: u64,
+    ) {
+        let tokenizer = self.tokenizer.clone();
+        let executor = self.backend.clone();
+
+        // Let's push this in async context
+        tokio::spawn(async move {
+            // Define the generation state
+            let mut generation = Generation {
+                executor: executor.clone(),
+                done: Arc::new(AtomicBool::new(false)),
+            };
+
+            // Define the context over the generation
+            // TODO(asap): Do we really need so many shared-ownership?
+            let ctx = Box::new(GenerationContext {
+                sender: sender.clone(),
+                tokenizer: tokenizer.clone(),
+                done: Arc::clone(&generation.done),
+            });
+
+            // We are leaking the context on-purpose to avoid the box being dropped while there are
+            // still computation ongoing
+            // TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
+            let ctx_ = Box::leak(ctx);
+
+            // Submit the request to the batcher
+            let request_id = span!(Level::DEBUG, "submit")
+                .in_scope(|| async {
+                    debug!("Acquiring lock for submit");
+                    let mut handle = executor.write().await;
+                    let request_id =
+                        handle
+                            .pin_mut()
+                            .submit(&tokens, top_k as i32, top_p, temperature, seed);
+
+                    debug!("Releasing lock for submit");
+                    request_id
+                })
+                .await;
+
+            while let Some(_) = generation.next().await {
+                span!(Level::DEBUG, "decode", request_id = request_id)
+                    .in_scope(|| async {
+                        let mut executor_w = executor.write().await;
+
+                        unsafe {
+                            debug!("Acquired write lock stream");
+                            executor_w.pin_mut().stream_tokens(
+                                request_id,
+                                ctx_,
+                                |ctx: *mut GenerationContext,
+                                 token: u32,
+                                 logprob: f32,
+                                 is_final: bool| {
+                                    // let text = ctx
+                                    //     .tokenizer
+                                    //     .decode(&[token], true)
+                                    //     .expect("Failed to decode token");
+                                    info!("Decoded token: {}", token);
+                                    let out = if is_final {
+                                        (*ctx).done.store(true, Ordering::Relaxed);
+                                        InferStreamResponse::End {
+                                            token: Token {
+                                                id: token,
+                                                text: "".into(),
+                                                logprob,
+                                                special: false,
+                                            },
+                                            top_tokens: vec![],
+                                            generated_text: GeneratedText {
+                                                text: "".into(),
+                                                generated_tokens: u32::MAX,
+                                                finish_reason: FinishReason::EndOfSequenceToken,
+                                                seed: None,
+                                            },
+                                            start: Instant::now(),
+                                            queued: Instant::now(),
+                                        }
+                                    } else {
+                                        InferStreamResponse::Intermediate {
+                                            token: Token {
+                                                id: token,
+                                                text: "".into(),
+                                                logprob,
+                                                special: false,
+                                            },
+                                            top_tokens: vec![],
+                                        }
+                                    };
+                                    (*ctx)
+                                        .sender
+                                        .send(Ok(out))
+                                        .expect("Failed to send back generated token");
+                                },
+                            );
+                            debug!("Releasing write lock stream")
+                        }
+                    })
+                    .await;
+            }
+
+            // "Properly" free the shared context...
+            // TODO: clean that piece of sh** asap
+            unsafe {
+                let _ = Box::from_raw(ctx_);
+            }
+        });
+    }
 }
 
 #[async_trait]
@@ -135,96 +274,32 @@ impl Backend for TensorRtLlmBackend {
     #[instrument(skip_all)]
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
+        // Let's add a few more validation
+        let input = TensorRtLlmBackend::validate(&request)?;
+
         // Channel to stream the generated token as they come from the worker thread back to the transport layer
         let (sender, receiver) = unbounded_channel();
 
-        let executor = self.backend.clone();
-        tokio::spawn(async move {
-            // Submit the request to the batcher
-            let request_id = span!(Level::DEBUG, "[EXECUTOR][SUBMIT]")
-                .in_scope(|| async {
-                    info!("Acquiring lock for submit");
-                    let mut handle = executor.write().await;
-                    let request_id = handle.pin_mut().submit(
-                        &vec![2, 2926, 1503, 603, 20189],
-                        50,
-                        1.0,
-                        1.0,
-                        2014,
-                    );
-
-                    info!("Releasing lock for submit");
-                    request_id
-                })
-                .await;
-
-            let mut generation = Generation {
-                executor: executor.clone(),
-                done: Arc::new(AtomicBool::new(false)),
-            };
-
-            while let Some(num_tokens_ready) = generation.next().await {
-                span!(
-                    Level::DEBUG,
-                    "[EXECUTOR][GENERATE]",
-                    request_id = request_id,
-                    num_tokens_ready = num_tokens_ready
-                )
-                .in_scope(|| async {
-                    let ctx = Box::new(GenerationContext(
-                        sender.clone(),
-                        Arc::clone(&generation.done),
-                    ));
-                    let mut executor_w = executor.write().await;
-
-                    info!("Acquired write lock stream");
-                    executor_w.pin_mut().stream_tokens(
-                        request_id,
-                        ctx,
-                        |ctx: Box<GenerationContext>, token: u32, logprob: f32, is_final: bool| {
-                            info!("Sending token: {} (final: {})", token, is_final);
-                            let out = if is_final {
-                                ctx.1.store(true, Ordering::Relaxed);
-                                InferStreamResponse::End {
-                                    token: Token {
-                                        id: token,
-                                        text: "".into(),
-                                        logprob,
-                                        special: false,
-                                    },
-                                    top_tokens: vec![],
-                                    generated_text: GeneratedText {
-                                        text: "".into(),
-                                        generated_tokens: u32::MAX,
-                                        finish_reason: FinishReason::EndOfSequenceToken,
-                                        seed: None,
-                                    },
-                                    start: Instant::now(),
-                                    queued: Instant::now(),
-                                }
-                            } else {
-                                InferStreamResponse::Intermediate {
-                                    token: Token {
-                                        id: token,
-                                        text: "".into(),
-                                        logprob,
-                                        special: false,
-                                    },
-                                    top_tokens: vec![],
-                                }
-                            };
-                            ctx.0
-                                .send(Ok(out))
-                                .expect("Failed to send back generated token");
-                        },
-                    );
-                    info!("Releasing write lock stream")
-                })
-                .await;
-            }
-        });
+        // Unpack parameters
+        let params = &request.parameters;
+
+        // Preprocess the inputs to send to TRTLLM backend
+        let encoding = self
+            .tokenizer
+            .encode(input.as_str(), true)
+            .map_err(|e| InferError::GenerationError(e.to_string()))?;
+
+        // Generate the response
+        self.generate(
+            sender,
+            Vec::from(encoding.get_ids()),
+            params.top_k,
+            params.top_p,
+            params.temperature,
+            params.seed,
+        );
 
         Ok(UnboundedReceiverStream::new(receiver))
     }
@@ -233,79 +308,3 @@ impl Backend for TensorRtLlmBackend {
         true
     }
 }
-
-// async fn background_looper<P: AsRef<Path>, PP: AsRef<Path>>(
-//     engine_folder: P,
-//     _executor_worker: Option<PP>,
-//     tokenizer: Tokenizer,
-//     mut receiver: UnboundedReceiver<InferenceContext>,
-// ) {
-//     let mut backend = create_tensorrt_llm_backend(engine_folder.as_ref().to_str().unwrap(), "");
-//
-//     while !(receiver.is_closed()) {
-//         // Receive the incoming request
-//         if let Some(ctx) = receiver.recv().await {
-//             debug!("Processing new incoming request");
-//
-//             // We only support single, textual chunk
-//             if ctx.request.inputs.len() != 1 {
-//                 propagate!(
-//                     ctx,
-//                     Err(InferError::GenerationError(format!(
-//                         "Unsupported multi-chunk ({}) input",
-//                         ctx.request.inputs.len()
-//                     )))
-//                 );
-//             }
-//
-//             let input = ctx
-//                 .request
-//                 .inputs
-//                 .first()
-//                 .expect("Single chunk checked above");
-//             let params = ctx.request.parameters;
-//         }
-// }
-
-// Receive the incoming request
-// if let Some(ctx) = receiver.recv().await {
-//     debug!("Processing new incoming request");
-
-//     // We only support single, textual chunk
-//     if ctx.request.inputs.len() != 1 {
-//         propagate!(
-//             ctx,
-//             Err(InferError::GenerationError(format!(
-//                 "Unsupported multi-chunk ({}) input",
-//                 ctx.request.inputs.len()
-//             )))
-//         );
-//     }
-//
-//     // Unpack parameters
-//     let inputs = ctx.request.inputs;
-//     let params = ctx.request.parameters;
-//
-//     match inputs.first().unwrap() {
-//         Chunk::Text(text) => match tokenizer.encode(text.as_str(), true) {
-//             Err(err) => {
-//                 propagate!(ctx, Err(InferError::GenerationError(err.to_string())))
-//             }
-//             Ok(encoding) => {
-//                 // spawn_blocking(|| {
-//                 //     info!("Submitting request to TensorRT-LLM executor");
-//                 //     let mut executor = backend.blocking_write();
-//                 // })
-//                 // .await
-//                 // .expect("");
-//             }
-//         },
-//         Chunk::Image(_) => propagate!(
-//             ctx,
-//             Err(InferError::GenerationError(
-//                 "Image input is not supported yet.".into(),
-//             ))
-//         ),
-//     }
-// };
-// }
@@ -33,8 +33,8 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 
 size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
         const uint64_t requestId,
-        rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
-        rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback) {
+        huggingface::tgi::backends::GenerationContext *ctx,
+        rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback) {
 
     size_t numTokens = 0;
     for (const auto &item: Poll(requestId)) {
@@ -44,12 +44,12 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
 
             const auto token = decoded.outputTokenIds[0][0];
             const auto isFinal = decoded.isFinal;
-            const auto logProb = decoded.logProbs.value()[0][0];
+            // const auto logProb = decoded.logProbs.value()[0][0];
 
             ++numTokens;
 
             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            callback(std::move(ctx), token, logProb, isFinal);
+            callback(std::move(ctx), token, 1.0, isFinal);
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rest::Result with error
@@ -54,11 +54,11 @@ mod ffi {
         ) -> u64;
 
         #[rust_name = "stream_tokens"]
-        fn StreamTokens(
+        unsafe fn StreamTokens(
             self: Pin<&mut TensorRtLlmBackendImpl>,
             request_id: u64,
-            ctx: Box<GenerationContext>,
-            cb: fn(Box<GenerationContext>, u32, f32, bool),
+            ctx: *mut GenerationContext,
+            cb: unsafe fn(*mut GenerationContext, u32, f32, bool),
         ) -> usize;
 
         // #[rust_name = "shutdown"]
@@ -777,6 +777,9 @@ pub enum ValidationError {
     InvalidImageContent(String),
     #[error("Could not fetch image: {0}")]
     FailedFetchImage(#[from] reqwest::Error),
+    #[error("{0} modality is not supported")]
+    UnsupportedModality(&'static str)
+
 }
 
 #[cfg(test)]
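The new UnsupportedModality variant gets its Display message from thiserror's #[error] attribute, the same mechanism the surrounding enum already uses. A small standalone sketch of how that renders, with the enum trimmed to the new variant for illustration (assumes the thiserror 1.x dependency pinned in the Cargo.toml hunk above):

    use thiserror::Error;

    #[derive(Debug, Error)]
    pub enum ValidationError {
        #[error("{0} modality is not supported")]
        UnsupportedModality(&'static str),
    }

    fn main() {
        let err = ValidationError::UnsupportedModality("image");
        // Display is generated from the #[error(...)] attribute.
        assert_eq!(err.to_string(), "image modality is not supported");
        println!("{err}");
    }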