diff --git a/docs/openapi.json b/docs/openapi.json index 7dc159a8..b2002efe 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -946,6 +946,38 @@ } } }, + "Chunk": { + "type": "object", + "required": [ + "id", + "created", + "choices", + "model", + "system_fingerprint" + ], + "properties": { + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CompletionComplete" + } + }, + "created": { + "type": "integer", + "format": "int64", + "minimum": 0 + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "system_fingerprint": { + "type": "string" + } + } + }, "CompatGenerateRequest": { "type": "object", "required": [ @@ -965,6 +997,55 @@ } } }, + "Completion": { + "oneOf": [ + { + "allOf": [ + { + "$ref": "#/components/schemas/Chunk" + }, + { + "type": "object", + "required": [ + "object" + ], + "properties": { + "object": { + "type": "string", + "enum": [ + "text_completion" + ] + } + } + } + ] + }, + { + "allOf": [ + { + "$ref": "#/components/schemas/CompletionFinal" + }, + { + "type": "object", + "required": [ + "object" + ], + "properties": { + "object": { + "type": "string", + "enum": [ + "text_completion" + ] + } + } + } + ] + } + ], + "discriminator": { + "propertyName": "object" + } + }, "CompletionComplete": { "type": "object", "required": [ @@ -994,14 +1075,15 @@ } } }, - "CompletionCompleteChunk": { + "CompletionFinal": { "type": "object", "required": [ "id", "created", - "choices", "model", - "system_fingerprint" + "system_fingerprint", + "choices", + "usage" ], "properties": { "choices": { @@ -1013,16 +1095,21 @@ "created": { "type": "integer", "format": "int64", + "example": "1706270835", "minimum": 0 }, "id": { "type": "string" }, "model": { - "type": "string" + "type": "string", + "example": "mistralai/Mistral-7B-Instruct-v0.2" }, "system_fingerprint": { "type": "string" + }, + "usage": { + "$ref": "#/components/schemas/Usage" } } }, diff --git a/router/src/lib.rs b/router/src/lib.rs index 9ecfa051..165b2ad2 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -433,8 +433,17 @@ pub struct CompletionRequest { pub stop: Option>, } +#[derive(Clone, Serialize, ToSchema)] +#[serde(tag = "object")] +enum Completion { + #[serde(rename = "text_completion")] + Chunk(Chunk), + #[serde(rename = "text_completion")] + Final(CompletionFinal), +} + #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] -pub(crate) struct Completion { +pub(crate) struct CompletionFinal { pub id: String, #[schema(example = "1706270835")] pub created: u64, @@ -453,6 +462,15 @@ pub(crate) struct CompletionComplete { pub finish_reason: String, } +#[derive(Clone, Deserialize, Serialize, ToSchema)] +pub(crate) struct Chunk { + pub id: String, + pub created: u64, + pub choices: Vec, + pub model: String, + pub system_fingerprint: String, +} + #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, @@ -614,15 +632,6 @@ impl ChatCompletion { } } } -#[derive(Clone, Deserialize, Serialize, ToSchema)] -pub(crate) struct CompletionCompleteChunk { - pub id: String, - pub created: u64, - pub choices: Vec, - pub model: String, - pub system_fingerprint: String, -} - #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, diff --git a/router/src/server.rs b/router/src/server.rs index 9be6a35c..ce0d6675 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -19,7 +19,7 @@ use crate::{ use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, - ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, + ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, }; @@ -705,7 +705,7 @@ async fn completions( .as_secs(); event - .json_data(CompletionCompleteChunk { + .json_data(Completion::Chunk(Chunk { id: "".to_string(), created: current_time, @@ -718,7 +718,7 @@ async fn completions( model: model_id.clone(), system_fingerprint: system_fingerprint.clone(), - }) + })) .unwrap_or_else(|_e| Event::default()) }; @@ -931,7 +931,7 @@ async fn completions( .collect::, _>>() .map_err(|(status, Json(err))| (status, Json(err)))?; - let response = Completion { + let response = Completion::Final(CompletionFinal { id: "".to_string(), created: current_time, model: info.model_id.clone(), @@ -946,7 +946,7 @@ async fn completions( completion_tokens, total_tokens, }, - }; + }); // headers similar to `generate` but aggregated let mut headers = HeaderMap::new(); @@ -1464,7 +1464,9 @@ pub async fn run( ChatCompletion, CompletionRequest, CompletionComplete, - CompletionCompleteChunk, + Chunk, + Completion, + CompletionFinal, GenerateParameters, PrefillToken, Token,