diff --git a/router/src/vertex.rs b/router/src/vertex.rs
index 0c1467fe..a532c9ec 100644
--- a/router/src/vertex.rs
+++ b/router/src/vertex.rs
@@ -1,9 +1,6 @@
 use crate::infer::Infer;
 use crate::server::{generate_internal, ComputeType};
-use crate::{
-    ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
-    StreamOptions, Tool, ToolChoice,
-};
+use crate::{ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
 use axum::response::{IntoResponse, Response};
@@ -21,162 +18,12 @@ pub(crate) struct GenerateVertexInstance {
     pub parameters: Option<GenerateParameters>,
 }
 
-#[derive(Clone, Deserialize, ToSchema)]
-#[cfg_attr(test, derive(Debug, PartialEq))]
-pub(crate) struct VertexChat {
-    messages: Vec<Message>,
-    // Messages is ignored there.
-    #[serde(default)]
-    parameters: VertexParameters,
-}
-
-#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
-#[cfg_attr(test, derive(Debug, PartialEq))]
-pub(crate) struct VertexParameters {
-    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
-    /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
-    pub model: Option<String>,
-
-    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
-    /// decreasing the model's likelihood to repeat the same line verbatim.
-    #[serde(default)]
-    #[schema(example = "1.0")]
-    pub frequency_penalty: Option<f32>,
-
-    /// UNUSED
-    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
-    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
-    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
-    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
-    /// result in a ban or exclusive selection of the relevant token.
-    #[serde(default)]
-    pub logit_bias: Option<Vec<f32>>,
-
-    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
-    /// output token returned in the content of message.
-    #[serde(default)]
-    #[schema(example = "false")]
-    pub logprobs: Option<bool>,
-
-    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
-    /// an associated log probability. logprobs must be set to true if this parameter is used.
-    #[serde(default)]
-    #[schema(example = "5")]
-    pub top_logprobs: Option<u32>,
-
-    /// The maximum number of tokens that can be generated in the chat completion.
-    #[serde(default)]
-    #[schema(example = "32")]
-    pub max_tokens: Option<u32>,
-
-    /// UNUSED
-    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
-    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
-    #[serde(default)]
-    #[schema(nullable = true, example = "2")]
-    pub n: Option<u32>,
-
-    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
-    /// increasing the model's likelihood to talk about new topics
-    #[serde(default)]
-    #[schema(nullable = true, example = 0.1)]
-    pub presence_penalty: Option<f32>,
-
-    /// Up to 4 sequences where the API will stop generating further tokens.
-    #[serde(default)]
-    #[schema(nullable = true, example = "null")]
-    pub stop: Option<Vec<String>>,
-
-    #[serde(default = "bool::default")]
-    pub stream: bool,
-
-    #[schema(nullable = true, example = 42)]
-    pub seed: Option<u64>,
-
-    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
-    /// lower values like 0.2 will make it more focused and deterministic.
-    ///
-    /// We generally recommend altering this or `top_p` but not both.
-    #[serde(default)]
-    #[schema(nullable = true, example = 1.0)]
-    pub temperature: Option<f32>,
-
-    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
-    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
-    #[serde(default)]
-    #[schema(nullable = true, example = 0.95)]
-    pub top_p: Option<f32>,
-
-    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
-    /// functions the model may generate JSON inputs for.
-    #[serde(default)]
-    #[schema(nullable = true, example = "null")]
-    pub tools: Option<Vec<Tool>>,
-
-    /// A prompt to be appended before the tools
-    #[serde(default)]
-    #[schema(
-        nullable = true,
-        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
-    )]
-    pub tool_prompt: Option<String>,
-
-    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
-    #[serde(default)]
-    #[schema(nullable = true, example = "null")]
-    pub tool_choice: ToolChoice,
-
-    /// Response format constraints for the generation.
-    ///
-    /// NOTE: A request can use `response_format` OR `tools` but not both.
-    #[serde(default)]
-    #[schema(nullable = true, default = "null", example = "null")]
-    pub response_format: Option<GrammarType>,
-
-    /// A guideline to be used in the chat_template
-    #[serde(default)]
-    #[schema(nullable = true, default = "null", example = "null")]
-    pub guideline: Option<String>,
-
-    /// Options for streaming response. Only set this when you set stream: true.
-    #[serde(default)]
-    #[schema(nullable = true, example = "null")]
-    pub stream_options: Option<StreamOptions>,
-}
-
-impl From<VertexChat> for ChatRequest {
-    fn from(val: VertexChat) -> Self {
-        Self {
-            messages: val.messages,
-            frequency_penalty: val.parameters.frequency_penalty,
-            guideline: val.parameters.guideline,
-            logit_bias: val.parameters.logit_bias,
-            logprobs: val.parameters.logprobs,
-            max_tokens: val.parameters.max_tokens,
-            model: val.parameters.model,
-            n: val.parameters.n,
-            presence_penalty: val.parameters.presence_penalty,
-            response_format: val.parameters.response_format,
-            seed: val.parameters.seed,
-            stop: val.parameters.stop,
-            stream_options: val.parameters.stream_options,
-            stream: val.parameters.stream,
-            temperature: val.parameters.temperature,
-            tool_choice: val.parameters.tool_choice,
-            tool_prompt: val.parameters.tool_prompt,
-            tools: val.parameters.tools,
-            top_logprobs: val.parameters.top_logprobs,
-            top_p: val.parameters.top_p,
-        }
-    }
-}
-
 #[derive(Clone, Deserialize, ToSchema)]
 #[cfg_attr(test, derive(Debug, PartialEq))]
 #[serde(untagged)]
 pub(crate) enum VertexInstance {
     Generate(GenerateVertexInstance),
-    Chat(VertexChat),
+    Chat(ChatRequest),
 }
 
 #[derive(Deserialize, ToSchema)]
@@ -257,9 +104,8 @@ pub(crate) async fn vertex_compatibility(
                 },
             },
             VertexInstance::Chat(instance) => {
-                let chat_request: ChatRequest = instance.into();
                 let (generate_request, _using_tools): (GenerateRequest, bool) =
-                    chat_request.try_into_generate(&infer)?;
+                    instance.try_into_generate(&infer)?;
                 generate_request
             }
         };
@@ -305,34 +151,14 @@ mod tests {
 
     #[test]
     fn vertex_deserialization() {
-        let string = serde_json::json!({
-
-            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
-            "parameters": {
-                "max_tokens": 128,
-                "top_p": 0.95,
-                "temperature": 0.7
-            }
-        });
-
-        let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
-
-        let string = serde_json::json!({
-            "messages": [{"role": "user", "content": "What's Deep Learning?"}],
-        });
-
-        let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
-
         let string = serde_json::json!({
 
             "instances": [
                 {
                     "messages": [{"role": "user", "content": "What's Deep Learning?"}],
-                    "parameters": {
-                        "max_tokens": 128,
-                        "top_p": 0.95,
-                        "temperature": 0.7
-                    }
+                    "max_tokens": 128,
+                    "top_p": 0.95,
+                    "temperature": 0.7
                 }
             ]
 
@@ -341,18 +167,16 @@ mod tests {
         assert_eq!(
             request,
             VertexRequest {
-                instances: vec![VertexInstance::Chat(VertexChat {
+                instances: vec![VertexInstance::Chat(ChatRequest {
                     messages: vec![Message {
                         role: "user".to_string(),
                         content: MessageContent::SingleText("What's Deep Learning?".to_string()),
                         name: None,
                     },],
-                    parameters: VertexParameters {
-                        max_tokens: Some(128),
-                        top_p: Some(0.95),
-                        temperature: Some(0.7),
-                        ..Default::default()
-                    }
+                    max_tokens: Some(128),
+                    top_p: Some(0.95),
+                    temperature: Some(0.7),
+                    ..Default::default()
                 })]
             }
         );