hf_text-generation-inference/router/src/lib.rs

205 lines
5.2 KiB
Rust
Raw Normal View History

//! Text Generation Inference Webserver
mod infer;
mod queue;
2022-10-17 10:27:33 -06:00
pub mod server;
2022-10-18 07:19:03 -06:00
mod validation;
2022-10-17 06:59:00 -06:00
use infer::Infer;
use queue::{Entry, Queue};
2022-10-18 07:19:03 -06:00
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
2022-10-17 06:59:00 -06:00
use validation::Validation;
2022-10-18 07:19:03 -06:00
#[derive(Clone, Debug, Deserialize, ToSchema)]
2022-10-18 07:19:03 -06:00
pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
nullable = true,
default = "null",
example = 0.5
)]
pub temperature: Option<f32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
nullable = true,
default = "null",
example = 1.03
)]
pub repetition_penalty: Option<f32>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
pub top_k: Option<i32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub top_p: Option<f32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub typical_p: Option<f32>,
#[serde(default)]
#[schema(default = "false", example = true)]
2022-10-18 07:19:03 -06:00
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
2022-10-18 07:19:03 -06:00
pub max_new_tokens: u32,
2022-12-15 09:03:56 -07:00
#[serde(default)]
#[schema(default = "null", example = false)]
pub return_full_text: Option<bool>,
#[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
2022-12-12 10:25:22 -07:00
pub stop: Vec<String>,
2022-12-15 09:03:56 -07:00
#[serde(default)]
#[schema(default = "null", example = "null")]
pub truncate: Option<usize>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub watermark: bool,
#[serde(default)]
#[schema(default = "true")]
2022-12-15 09:03:56 -07:00
pub details: bool,
#[serde(default)]
pub seed: Option<u64>,
2022-10-18 07:19:03 -06:00
}
/// Fallback value for `GenerateParameters::max_new_tokens`, referenced by
/// its `#[serde(default = "default_max_new_tokens")]` attribute and by
/// `default_parameters()`.
fn default_max_new_tokens() -> u32 {
    const DEFAULT_MAX_NEW_TOKENS: u32 = 20;
    DEFAULT_MAX_NEW_TOKENS
}
fn default_parameters() -> GenerateParameters {
GenerateParameters {
temperature: None,
repetition_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: false,
2022-10-18 07:19:03 -06:00
max_new_tokens: default_max_new_tokens(),
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
2022-12-15 09:03:56 -07:00
details: false,
seed: None,
2022-10-18 07:19:03 -06:00
}
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
2022-10-18 07:19:03 -06:00
pub(crate) struct GenerateRequest {
#[schema(example = "My name is Olivier and I")]
2022-10-18 07:19:03 -06:00
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
}
// Same shape as `GenerateRequest`, with an extra `stream` flag.
// Converts into `GenerateRequest` through the `From` impl in this file.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
    // Required: there is no serde default for this field.
    #[schema(example = "My name is Olivier and I")]
    pub inputs: String,

    // Falls back to `default_parameters()` when the key is absent.
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,

    // Absent => `false`. The `From` conversion in this file does not read
    // this field, hence the `dead_code` allowance.
    #[serde(default)]
    #[allow(dead_code)]
    pub stream: bool,
}
impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self {
Self {
inputs: req.inputs,
parameters: req.parameters,
}
}
}
// A serializable token record with an id, its text and a log-probability.
// Used as the element type of `Details::prefill`.
#[derive(Serialize, Debug, ToSchema)]
pub struct PrefillToken {
    #[schema(example = 0)]
    id: u32,

    #[schema(example = "test")]
    text: String,

    // Documented as nullable in the schema even though the Rust type is a
    // plain `f32`.
    #[schema(nullable = true, example = -0.34)]
    logprob: f32,
}
// A serializable token record with an id, its text, a log-probability and a
// `special` flag. Used by `Details::tokens` and `StreamResponse::token`.
#[derive(Debug, Serialize, ToSchema)]
pub struct Token {
    #[schema(example = 0)]
    id: u32,

    #[schema(example = "test")]
    text: String,

    // Documented as nullable in the schema even though the Rust type is a
    // plain `f32`.
    #[schema(nullable = true, example = -0.34)]
    logprob: f32,

    // The example is a boolean literal so the documented sample value has the
    // same JSON type as the field (it used to be the string "false").
    #[schema(example = false)]
    special: bool,
}
// Reason attached to `Details` / `StreamDetails` as `finish_reason`.
//
// Variants serialize in snake_case, except `EndOfSequenceToken`, whose
// explicit rename gives "eos_token". The `#[schema(rename = ...)]`
// attributes repeat the same names for the OpenAPI document.
#[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
    // Serialized as "length".
    #[schema(rename = "length")]
    Length,

    // Serialized as "eos_token" (explicit rename overrides snake_case).
    #[serde(rename = "eos_token")]
    #[schema(rename = "eos_token")]
    EndOfSequenceToken,

    // Serialized as "stop_sequence".
    #[schema(rename = "stop_sequence")]
    StopSequence,
}
#[derive(Serialize, ToSchema)]
2022-12-15 09:03:56 -07:00
pub(crate) struct Details {
#[schema(example = "length")]
pub finish_reason: FinishReason,
#[schema(example = 1)]
2022-12-15 09:03:56 -07:00
pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>,
2023-03-07 10:52:22 -07:00
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
2022-12-15 09:03:56 -07:00
}
#[derive(Serialize, ToSchema)]
pub(crate) struct GenerateResponse {
#[schema(example = "test")]
2022-10-18 07:19:03 -06:00
pub generated_text: String,
2022-12-15 09:03:56 -07:00
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>,
2022-10-18 07:19:03 -06:00
}
2022-10-27 06:25:29 -06:00
// Detail payload carried by `StreamResponse::details`: the same first three
// fields as `Details`, without the `prefill` and `tokens` lists.
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
    #[schema(example = "length")]
    pub finish_reason: FinishReason,

    #[schema(example = 1)]
    pub generated_tokens: u32,

    #[schema(example = 42)]
    pub seed: Option<u64>,
}
// Serialized streaming message: one token plus two optional members.
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
    pub token: Token,

    // No `skip_serializing_if` here, so `None` is serialized as `null`.
    #[schema(nullable = true, default = "null", example = "test")]
    pub generated_text: Option<String>,

    // No `skip_serializing_if` here, so `None` is serialized as `null`.
    #[schema(nullable = true, default = "null")]
    pub details: Option<StreamDetails>,
}
#[derive(Serialize, ToSchema)]
2022-10-27 06:25:29 -06:00
pub(crate) struct ErrorResponse {
pub error: String,
2023-03-07 10:52:22 -07:00
pub error_type: String,
2022-10-27 06:25:29 -06:00
}