feat(backend): partially generate the final infer response

This commit is contained in:
Morgan Funtowicz 2024-11-03 00:46:04 +01:00
parent b50dcddbb8
commit 3e82f14f57
1 changed file with 25 additions and 9 deletions

View File

@ -18,7 +18,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::TryAcquireError; use tokio::sync::TryAcquireError;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info}; use tracing::{debug, error, info};
type BoxedOpaqueStream = Box<OpaqueStream>; type BoxedOpaqueStream = Box<OpaqueStream>;
@ -113,7 +113,7 @@ fn llama_generate_callback(
}, },
top_tokens: vec![], top_tokens: vec![],
}; };
println!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos}"); debug!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos}");
unsafe { unsafe {
if let Err(ref err) = (*channel).0.send(Ok(response)) { if let Err(ref err) = (*channel).0.send(Ok(response)) {
@ -121,7 +121,7 @@ fn llama_generate_callback(
"Failed to send back token to the client: {}", "Failed to send back token to the client: {}",
err.to_string() err.to_string()
); );
} };
} }
} }
@ -131,6 +131,7 @@ unsafe fn scheduler_loop(
) { ) {
loop { loop {
if let Ok(mut ctx) = backlog.recv() { if let Ok(mut ctx) = backlog.recv() {
let start = Instant::now();
let stream = BoxedOpaqueStream::new(OpaqueStream(ctx.stream)); let stream = BoxedOpaqueStream::new(OpaqueStream(ctx.stream));
let stream_ptr = Box::into_raw(stream); let stream_ptr = Box::into_raw(stream);
let result = backend.pin_mut().stream( let result = backend.pin_mut().stream(
@ -143,7 +144,7 @@ unsafe fn scheduler_loop(
); );
// Make sure we re-keep track of the OpaqueStream box // Make sure we re-keep track of the OpaqueStream box
let _ = Box::from_raw(stream_ptr); let stream = Box::from_raw(stream_ptr);
match result { match result {
Ok(n_tokens) => { Ok(n_tokens) => {
@ -151,12 +152,27 @@ unsafe fn scheduler_loop(
ctx.generated_tokens.set_len(n_tokens); ctx.generated_tokens.set_len(n_tokens);
} }
println!( let _ = stream.0.send(Ok(InferStreamResponse::End {
"Generated {} tokens -> {:?}", token: Token {
n_tokens, &ctx.generated_tokens id: ctx.generated_tokens[n_tokens - 1],
); text: "".to_string(),
logprob: 0.0,
special: false,
},
top_tokens: vec![],
generated_text: GeneratedText {
text: "".to_string(),
generated_tokens: n_tokens as u32,
finish_reason: FinishReason::Length,
seed: Some(ctx.sampling_params.seed),
},
start,
queued: start,
}));
debug!("Generated {n_tokens} tokens -> {:?}", ctx.generated_tokens);
} }
Err(err) => println!("Error: {}", err), Err(err) => println!("Error: {err}"),
} }
} else { } else {
info!("IPC channel is closed, exiting the scheduler loop"); info!("IPC channel is closed, exiting the scheduler loop");