From 612f2f939f2b40d76db4a77032695fb90e1fd084 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Fri, 1 Nov 2024 00:50:42 +0100
Subject: [PATCH] feat(backend): bind incoming request to the server

---
 backends/llamacpp/src/backend.rs | 158 +++++++++++++++++++++++++------
 backends/llamacpp/src/lib.rs     |   2 +
 2 files changed, 129 insertions(+), 31 deletions(-)

diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 6e9e8d2d..670f4397 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -2,18 +2,54 @@ use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
 use async_trait::async_trait;
-use cxx::{Exception, UniquePtr};
+use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
+use std::sync::mpsc::{channel, Receiver, SendError, Sender};
 use std::sync::Arc;
-use std::thread::spawn;
+use std::thread::{spawn, JoinHandle};
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{
+    ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
+};
+use text_generation_router::Token;
 use thiserror::Error;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::TryAcquireError;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::info;
+use tracing::{error, info};
 
 unsafe impl Send for LlamaCppBackendImpl {}
 
+impl From<&ValidParameters> for SamplingParams {
+    fn from(v: &ValidParameters) -> Self {
+        Self {
+            top_k: v.top_k,
+            top_p: v.top_p,
+            frequency_penalty: v.frequency_penalty,
+            repetition_penalty: v.repetition_penalty,
+            seed: v.seed,
+        }
+    }
+}
+
+impl From<&ValidStoppingParameters> for GenerationParams {
+    fn from(v: &ValidStoppingParameters) -> Self {
+        Self {
+            max_new_tokens: v.max_new_tokens,
+            ignore_eos_token: v.ignore_eos_token,
+        }
+    }
+}
+
+#[cfg_attr(debug_assertions, derive(Debug))]
+struct InferContext {
+    pub(crate) stream: UnboundedSender<Result<InferStreamResponse, InferError>>,
+    pub(crate) input_tokens: Arc<Vec<u32>>,
+    pub(crate) generated_tokens: Vec<u32>,
+    pub(crate) generation_params: GenerationParams,
+    pub(crate) sampling_params: SamplingParams,
+}
+
 #[derive(Debug, Error)]
 pub enum LlamaCppBackendError {
     #[error("Provided GGUF model path {0} doesn't exist")]
@@ -23,7 +59,10 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }
 
-pub struct LlamaCppBackend {}
+pub struct LlamaCppBackend {
+    backlog: Sender<InferContext>,
+    scheduler_handle: JoinHandle<()>,
+}
 
 impl LlamaCppBackend {
     pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
@@ -34,7 +73,7 @@ impl LlamaCppBackend {
             ));
         }
 
-        let mut backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
+        let backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
             LlamaCppBackendError::ModelInitializationFailed(
                 path.to_path_buf(),
                 err.what().to_string(),
@@ -46,33 +85,67 @@ impl LlamaCppBackend {
             path.display()
         );
 
-        let j = spawn(|| scheduler_loop(backend));
-        j.join().ok();
-        Ok(Self {})
+        let (submitter, receiver) = channel();
+        let handle = spawn(|| scheduler_loop(backend, receiver));
+        Ok(Self {
+            backlog: submitter,
+            scheduler_handle: handle,
+        })
     }
 }
 
-fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {
-    println!("Scheduler loop");
-    let tokens = [128000u32, 5159, 836, 374, 23809];
-    let mut generated = vec![0u32; 16];
-    let generation_params = GenerationParams {
-        max_new_tokens: generated.len() as u32,
-    };
-    let sampling_params = SamplingParams::default();
-
-    match backend.pin_mut().generate(
-        &tokens,
-        &mut generated,
-        &generation_params,
-        &sampling_params,
-        |new_token_id: u32, is_eos: bool| println!("Generated {new_token_id} (is_eos: {is_eos})"),
-    ) {
-        Ok(n_tokens) => {
-            generated.truncate(n_tokens);
-            println!("Generated {} tokens -> {:?}", n_tokens, generated);
+fn scheduler_loop(
+    mut backend: UniquePtr<LlamaCppBackendImpl>,
+    mut backlog: Receiver<InferContext>,
+) {
+    loop {
+        println!("Looping");
+        if let Ok(mut ctx) = backlog.recv() {
+            println!("{ctx:?}, {}", &ctx.generated_tokens.capacity());
+            match backend.pin_mut().generate(
+                &ctx.input_tokens,
+                &mut ctx.generated_tokens,
+                &ctx.generation_params,
+                &ctx.sampling_params,
+                |new_token_id: u32, new_token_logit: f32, is_eos: bool| {
+                    let response = InferStreamResponse::Intermediate {
+                        token: Token {
+                            id: new_token_id,
+                            text: "".to_string(),
+                            logprob: new_token_logit,
+                            special: false,
+                        },
+                        top_tokens: vec![],
+                    };
+                    println!("Generated token: {response:?}");
+                    // let _ = tokio::spawn(async {
+                    //     match ctx.stream.send(Ok(response)) {
+                    //         Ok(_) => {}
+                    //         Err(ref err) => {
+                    //             error!(
+                    //                 "Failed to send back token to the client: {}",
+                    //                 err.to_string()
+                    //             );
+                    //         }
+                    //     }
+                    // });
+                },
+            ) {
+                Ok(n_tokens) => {
+                    unsafe {
+                        ctx.generated_tokens.set_len(n_tokens);
+                    }
+                    println!(
+                        "Generated {} tokens -> {:?}",
+                        n_tokens, &ctx.generated_tokens
+                    );
+                }
+                Err(err) => println!("Error: {}", err),
+            }
+        } else {
+            info!("IPC channel is closed, exiting the scheduler loop");
+            break;
         }
-        Err(err) => println!("Error: {}", err),
     }
 }
 
@@ -80,9 +153,32 @@ fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {
 impl Backend for LlamaCppBackend {
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        Err(InferError::GenerationError("Not implemented yet".into()))
+        if let Some(input_ids) = request.input_ids {
+            let (sx, rx) = unbounded_channel();
+            let sampling_params = SamplingParams::from(&request.parameters);
+            let generation_params = GenerationParams::from(&request.stopping_parameters);
+
+            let ctx = InferContext {
+                stream: sx,
+                input_tokens: Arc::clone(&input_ids),
+                generated_tokens: Vec::with_capacity(generation_params.max_new_tokens as usize),
+                generation_params,
+                sampling_params,
+            };
+
+            match self.backlog.send(ctx) {
+                Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
+                Err(_) => Err(InferError::GenerationError(
+                    "Failed to send the request".to_string(),
+                )),
+            }
+        } else {
+            Err(InferError::GenerationError(
+                "Unsupported modalities".to_string(),
+            ))
+        }
     }
 
     async fn health(&self, _: bool) -> bool {
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index 8d51a15a..489188c1 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -16,11 +16,13 @@ impl Default for SamplingParams {
 
 #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
+    #[derive(Debug, Copy, Clone)]
     struct GenerationParams {
         max_new_tokens: u32,
         ignore_eos_token: bool,
     }
 
+    #[derive(Debug, Copy, Clone)]
     struct SamplingParams {
         top_k: u32,
         top_p: f32,