From 6489f85269ffb91ab1c62c3b76964167206b850a Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 15 Nov 2024 08:49:19 -0500 Subject: [PATCH] feat: return streaming errors as an event formatted for openai's client (#2668) * feat: return streaming errors as an event formatted for openai's client * fix: propagate completions error events to stream * fix: improve stream api error format and add status code * fix: improve streamin error to include error_type * Revert "fix: improve streamin error to include error_type" This reverts commit 2b1a360b1511d94ea9a24e5432e498e67939506a. * Reworked the implementation. * Revert "Reworked the implementation." This reverts commit 7c3f29777f17411ae4ade57e2f88e73cde704ee5. * Small lifting. --------- Co-authored-by: Nicolas Patry --- router/src/infer/mod.rs | 24 ++++++++++++++++++++++++ router/src/server.rs | 7 +++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 557e03cb..d3d6bc59 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -10,10 +10,12 @@ use crate::{ }; use async_stream::stream; use async_trait::async_trait; +use axum::response::sse::Event; use chat_template::ChatTemplate; use futures::future::try_join_all; use futures::Stream; use minijinja::ErrorKind; +use serde::Serialize; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use thiserror::Error; @@ -373,4 +375,26 @@ impl InferError { InferError::StreamSerializationError(_) => "stream_serialization_error", } } + + pub(crate) fn into_openai_event(self) -> Event { + Event::default() + .json_data(OpenaiErrorEvent { + error: APIError { + message: self.to_string(), + http_status_code: 422, + }, + }) + .unwrap() + } +} + +#[derive(Serialize)] +pub struct APIError { + message: String, + http_status_code: usize, +} + +#[derive(Serialize)] +pub struct OpenaiErrorEvent { + error: APIError, } diff --git a/router/src/server.rs b/router/src/server.rs index a0bc1768..cbb04174 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -866,7 +866,7 @@ pub(crate) async fn completions( yield Ok(event); } - Err(err) => yield Ok(Event::from(err)), + Err(err) => yield Ok(err.into_openai_event()), } } }; @@ -1274,7 +1274,8 @@ pub(crate) async fn chat_completions( }; let mut response_as_tool = using_tools; while let Some(result) = response_stream.next().await { - if let Ok(stream_token) = result { + match result{ + Ok(stream_token) => { let token_text = &stream_token.token.text.clone(); match state { StreamState::Buffering => { @@ -1368,6 +1369,8 @@ pub(crate) async fn chat_completions( } } } + Err(err) => yield Ok(err.into_openai_event()) + } } yield Ok::(Event::default().data("[DONE]")); };