diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 557e03cb..d3d6bc59 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -10,10 +10,12 @@ use crate::{
 };
 use async_stream::stream;
 use async_trait::async_trait;
+use axum::response::sse::Event;
 use chat_template::ChatTemplate;
 use futures::future::try_join_all;
 use futures::Stream;
 use minijinja::ErrorKind;
+use serde::Serialize;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use thiserror::Error;
@@ -373,4 +375,38 @@ impl InferError {
             InferError::StreamSerializationError(_) => "stream_serialization_error",
         }
     }
+
+    /// Converts this error into a server-sent [`Event`] carrying an OpenAI-style
+    /// error payload: `{"error": {"message": "...", "http_status_code": 422}}`.
+    ///
+    /// The message is the error's `Display` text; the status code is fixed at 422
+    /// for every variant.
+    pub(crate) fn into_openai_event(self) -> Event {
+        Event::default()
+            .json_data(OpenaiErrorEvent {
+                error: APIError {
+                    message: self.to_string(),
+                    http_status_code: 422,
+                },
+            })
+            // Invariant: the payload is a `String` plus an integer, which serde
+            // can always encode as JSON.
+            .expect("OpenaiErrorEvent is always serializable to JSON")
+    }
+}
+
+/// Inner `error` object of an OpenAI-style streaming error payload.
+#[derive(Serialize)]
+pub struct APIError {
+    /// Human-readable description of the failure.
+    message: String,
+    /// HTTP status code reported to the client inside the payload.
+    http_status_code: usize,
+}
+
+/// Top-level JSON body of the error event emitted on the SSE stream.
+#[derive(Serialize)]
+pub struct OpenaiErrorEvent {
+    /// The wrapped error details, serialized under the `error` key.
+    error: APIError,
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index a0bc1768..cbb04174 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -866,7 +866,7 @@ pub(crate) async fn completions(
                                     yield Ok(event);
                                 }
 
-                                Err(err) => yield Ok(Event::from(err)),
+                                Err(err) => yield Ok(err.into_openai_event()),
                             }
                         }
                     };
@@ -1274,7 +1274,8 @@ pub(crate) async fn chat_completions(
             };
             let mut response_as_tool = using_tools;
             while let Some(result) = response_stream.next().await {
-                if let Ok(stream_token) = result {
+                match result {
+                    Ok(stream_token) => {
                     let token_text = &stream_token.token.text.clone();
                     match state {
                         StreamState::Buffering => {
@@ -1368,6 +1369,8 @@ pub(crate) async fn chat_completions(
                         }
                     }
                 }
+                    Err(err) => yield Ok(err.into_openai_event())
+                }
             }
             yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
         };