feat: improve update_docs for openapi schema (#2169)

* feat: add pre commit step to force schema update when router changes

* fix: prefer improved update_doc and start server and compare

* fix: adjust typo

* fix: adjust revert typo

* fix: update workflow to use update_doc md command

* feat: improve workflow to check openapi schema too

* fix: adjust timeout for CI

* fix: adjust raise condition and install server in ci

* fix: install protoc before server

* feat: improve update doc and add command to print router schema

* fix: adjust autodoc workflow

* fix: explicitly install protoc and python

* fix: alllow trailing space in openapi schema diff
This commit is contained in:
drbh 2024-07-03 03:53:35 -04:00 committed by GitHub
parent 0759ec495e
commit 571530dd9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 180 additions and 97 deletions

View File

@ -11,10 +11,30 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
- name: Set up Rust
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
- name: Install Protocol Buffers compiler
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler libprotobuf-dev
- name: Install Launcher - name: Install Launcher
id: install-launcher id: install-launcher
run: cargo install --path launcher/ run: cargo install --path launcher/
- name: Check launcher Docs are up-to-date
- name: Install router
id: install-router
run: cargo install --path router/
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Check that documentation is up-to-date
run: | run: |
echo text-generation-launcher --help
python update_doc.py --check python update_doc.py --check

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "2.0.1" "version": "2.1.1-dev0"
}, },
"paths": { "paths": {
"/": { "/": {
@ -19,7 +19,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`", "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"operationId": "compat_generate", "operationId": "compat_generate",
"requestBody": { "requestBody": {
"content": { "content": {
@ -108,7 +107,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate tokens", "summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "generate", "operationId": "generate",
"requestBody": { "requestBody": {
"content": { "content": {
@ -192,7 +190,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate a stream of token using Server-Sent Events", "summary": "Generate a stream of token using Server-Sent Events",
"description": "Generate a stream of token using Server-Sent Events",
"operationId": "generate_stream", "operationId": "generate_stream",
"requestBody": { "requestBody": {
"content": { "content": {
@ -276,7 +273,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Health check method", "summary": "Health check method",
"description": "Health check method",
"operationId": "health", "operationId": "health",
"responses": { "responses": {
"200": { "200": {
@ -305,7 +301,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Text Generation Inference endpoint info", "summary": "Text Generation Inference endpoint info",
"description": "Text Generation Inference endpoint info",
"operationId": "get_model_info", "operationId": "get_model_info",
"responses": { "responses": {
"200": { "200": {
@ -327,7 +322,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Prometheus metrics scrape endpoint", "summary": "Prometheus metrics scrape endpoint",
"description": "Prometheus metrics scrape endpoint",
"operationId": "metrics", "operationId": "metrics",
"responses": { "responses": {
"200": { "200": {
@ -349,7 +343,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Tokenize inputs", "summary": "Tokenize inputs",
"description": "Tokenize inputs",
"operationId": "tokenize", "operationId": "tokenize",
"requestBody": { "requestBody": {
"content": { "content": {
@ -394,7 +387,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate tokens", "summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "chat_completions", "operationId": "chat_completions",
"requestBody": { "requestBody": {
"content": { "content": {
@ -483,7 +475,6 @@
"Text Generation Inference" "Text Generation Inference"
], ],
"summary": "Generate tokens", "summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "completions", "operationId": "completions",
"requestBody": { "requestBody": {
"content": { "content": {
@ -626,7 +617,6 @@
"type": "object", "type": "object",
"required": [ "required": [
"id", "id",
"object",
"created", "created",
"model", "model",
"system_fingerprint", "system_fingerprint",
@ -653,9 +643,6 @@
"type": "string", "type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2" "example": "mistralai/Mistral-7B-Instruct-v0.2"
}, },
"object": {
"type": "string"
},
"system_fingerprint": { "system_fingerprint": {
"type": "string" "type": "string"
}, },
@ -697,7 +684,6 @@
"type": "object", "type": "object",
"required": [ "required": [
"id", "id",
"object",
"created", "created",
"model", "model",
"system_fingerprint", "system_fingerprint",
@ -723,9 +709,6 @@
"type": "string", "type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2" "example": "mistralai/Mistral-7B-Instruct-v0.2"
}, },
"object": {
"type": "string"
},
"system_fingerprint": { "system_fingerprint": {
"type": "string" "type": "string"
} }
@ -756,34 +739,19 @@
"nullable": true "nullable": true
}, },
"message": { "message": {
"$ref": "#/components/schemas/Message" "$ref": "#/components/schemas/OutputMessage"
} }
} }
}, },
"ChatCompletionDelta": { "ChatCompletionDelta": {
"type": "object", "oneOf": [
"required": [ {
"role" "$ref": "#/components/schemas/TextMessage"
],
"properties": {
"content": {
"type": "string",
"example": "What is Deep Learning?",
"nullable": true
}, },
"role": { {
"type": "string", "$ref": "#/components/schemas/ToolCallDelta"
"example": "user"
},
"tool_calls": {
"allOf": [
{
"$ref": "#/components/schemas/DeltaToolCall"
}
],
"nullable": true
} }
} ]
}, },
"ChatCompletionLogprob": { "ChatCompletionLogprob": {
"type": "object", "type": "object",
@ -903,6 +871,15 @@
"example": 0.1, "example": 0.1,
"nullable": true "nullable": true
}, },
"response_format": {
"allOf": [
{
"$ref": "#/components/schemas/GrammarType"
}
],
"default": "null",
"nullable": true
},
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
@ -1021,7 +998,6 @@
"type": "object", "type": "object",
"required": [ "required": [
"id", "id",
"object",
"created", "created",
"choices", "choices",
"model", "model",
@ -1045,9 +1021,6 @@
"model": { "model": {
"type": "string" "type": "string"
}, },
"object": {
"type": "string"
},
"system_fingerprint": { "system_fingerprint": {
"type": "string" "type": "string"
} }
@ -1081,12 +1054,7 @@
"example": "mistralai/Mistral-7B-Instruct-v0.2" "example": "mistralai/Mistral-7B-Instruct-v0.2"
}, },
"prompt": { "prompt": {
"type": "array", "$ref": "#/components/schemas/Prompt"
"items": {
"type": "string"
},
"description": "The prompt to generate completions for.",
"example": "What is Deep Learning?"
}, },
"repetition_penalty": { "repetition_penalty": {
"type": "number", "type": "number",
@ -1100,6 +1068,15 @@
"nullable": true, "nullable": true,
"minimum": 0 "minimum": 0
}, },
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
},
"stream": { "stream": {
"type": "boolean" "type": "boolean"
}, },
@ -1121,15 +1098,6 @@
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95, "example": 0.95,
"nullable": true "nullable": true
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
} }
} }
}, },
@ -1272,8 +1240,16 @@
"GenerateParameters": { "GenerateParameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"adapter_id": {
"type": "string",
"description": "Lora adapter id",
"default": "null",
"example": "null",
"nullable": true
},
"best_of": { "best_of": {
"type": "integer", "type": "integer",
"description": "Generate best_of sequences and return the one if the highest token logprobs.",
"default": "null", "default": "null",
"example": 1, "example": 1,
"nullable": true, "nullable": true,
@ -1282,20 +1258,24 @@
}, },
"decoder_input_details": { "decoder_input_details": {
"type": "boolean", "type": "boolean",
"description": "Whether to return decoder input token logprobs and ids.",
"default": "false" "default": "false"
}, },
"details": { "details": {
"type": "boolean", "type": "boolean",
"description": "Whether to return generation details.",
"default": "true" "default": "true"
}, },
"do_sample": { "do_sample": {
"type": "boolean", "type": "boolean",
"description": "Activate logits sampling.",
"default": "false", "default": "false",
"example": true "example": true
}, },
"frequency_penalty": { "frequency_penalty": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
"default": "null", "default": "null",
"example": 0.1, "example": 0.1,
"nullable": true, "nullable": true,
@ -1313,6 +1293,7 @@
"max_new_tokens": { "max_new_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"description": "Maximum number of tokens to generate.",
"default": "100", "default": "100",
"example": "20", "example": "20",
"nullable": true, "nullable": true,
@ -1321,6 +1302,7 @@
"repetition_penalty": { "repetition_penalty": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "The parameter for repetition penalty. 1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.",
"default": "null", "default": "null",
"example": 1.03, "example": 1.03,
"nullable": true, "nullable": true,
@ -1328,6 +1310,7 @@
}, },
"return_full_text": { "return_full_text": {
"type": "boolean", "type": "boolean",
"description": "Whether to prepend the prompt to the generated text",
"default": "null", "default": "null",
"example": false, "example": false,
"nullable": true "nullable": true
@ -1335,6 +1318,7 @@
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"description": "Random sampling seed.",
"default": "null", "default": "null",
"example": "null", "example": "null",
"nullable": true, "nullable": true,
@ -1346,6 +1330,7 @@
"items": { "items": {
"type": "string" "type": "string"
}, },
"description": "Stop generating tokens if a member of `stop` is generated.",
"example": [ "example": [
"photographer" "photographer"
], ],
@ -1354,6 +1339,7 @@
"temperature": { "temperature": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "The value used to module the logits distribution.",
"default": "null", "default": "null",
"example": 0.5, "example": 0.5,
"nullable": true, "nullable": true,
@ -1362,6 +1348,7 @@
"top_k": { "top_k": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
"default": "null", "default": "null",
"example": 10, "example": 10,
"nullable": true, "nullable": true,
@ -1370,6 +1357,7 @@
"top_n_tokens": { "top_n_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.",
"default": "null", "default": "null",
"example": 5, "example": 5,
"nullable": true, "nullable": true,
@ -1379,6 +1367,7 @@
"top_p": { "top_p": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "Top-p value for nucleus sampling.",
"default": "null", "default": "null",
"example": 0.95, "example": 0.95,
"nullable": true, "nullable": true,
@ -1387,6 +1376,7 @@
}, },
"truncate": { "truncate": {
"type": "integer", "type": "integer",
"description": "Truncate inputs tokens to the given size.",
"default": "null", "default": "null",
"example": "null", "example": "null",
"nullable": true, "nullable": true,
@ -1395,6 +1385,7 @@
"typical_p": { "typical_p": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.",
"default": "null", "default": "null",
"example": 0.95, "example": 0.95,
"nullable": true, "nullable": true,
@ -1403,6 +1394,7 @@
}, },
"watermark": { "watermark": {
"type": "boolean", "type": "boolean",
"description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).",
"default": "false", "default": "false",
"example": true "example": true
} }
@ -1495,13 +1487,14 @@
"max_concurrent_requests", "max_concurrent_requests",
"max_best_of", "max_best_of",
"max_stop_sequences", "max_stop_sequences",
"max_input_length", "max_input_tokens",
"max_total_tokens", "max_total_tokens",
"waiting_served_ratio", "waiting_served_ratio",
"max_batch_total_tokens", "max_batch_total_tokens",
"max_waiting_tokens", "max_waiting_tokens",
"validation_workers", "validation_workers",
"max_client_batch_size", "max_client_batch_size",
"router",
"version" "version"
], ],
"properties": { "properties": {
@ -1538,7 +1531,7 @@
"example": "128", "example": "128",
"minimum": 0 "minimum": 0
}, },
"max_input_length": { "max_input_tokens": {
"type": "integer", "type": "integer",
"example": "1024", "example": "1024",
"minimum": 0 "minimum": 0
@ -1581,6 +1574,11 @@
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2", "example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
"nullable": true "nullable": true
}, },
"router": {
"type": "string",
"description": "Router Info",
"example": "text-generation-router"
},
"sha": { "sha": {
"type": "string", "type": "string",
"example": "null", "example": "null",
@ -1593,7 +1591,6 @@
}, },
"version": { "version": {
"type": "string", "type": "string",
"description": "Router Info",
"example": "0.5.0" "example": "0.5.0"
}, },
"waiting_served_ratio": { "waiting_served_ratio": {
@ -1606,13 +1603,12 @@
"Message": { "Message": {
"type": "object", "type": "object",
"required": [ "required": [
"role" "role",
"content"
], ],
"properties": { "properties": {
"content": { "content": {
"type": "string", "$ref": "#/components/schemas/MessageContent"
"example": "My name is David and I",
"nullable": true
}, },
"name": { "name": {
"type": "string", "type": "string",
@ -1622,13 +1618,6 @@
"role": { "role": {
"type": "string", "type": "string",
"example": "user" "example": "user"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolCall"
},
"nullable": true
} }
} }
}, },
@ -1817,9 +1806,7 @@
"$ref": "#/components/schemas/FunctionDefinition" "$ref": "#/components/schemas/FunctionDefinition"
}, },
"id": { "id": {
"type": "integer", "type": "string"
"format": "int32",
"minimum": 0
}, },
"type": { "type": {
"type": "string" "type": "string"
@ -1830,20 +1817,22 @@
"oneOf": [ "oneOf": [
{ {
"type": "object", "type": "object",
"required": [ "default": null,
"FunctionName" "nullable": true
],
"properties": {
"FunctionName": {
"type": "string"
}
}
}, },
{ {
"type": "string", "type": "string"
"enum": [ },
"OneOf" {
] "type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
} }
] ]
}, },

View File

@ -1,5 +1,6 @@
use axum::http::HeaderValue; use axum::http::HeaderValue;
use clap::Parser; use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType}; use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator;
@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)] #[clap(author, version, about, long_about = None)]
struct Args { struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
@ -85,10 +89,15 @@ struct Args {
max_client_batch_size: usize, max_client_batch_size: usize,
} }
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), RouterError> { async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse(); let args = Args::parse();
// Pattern match configuration // Pattern match configuration
let Args { let Args {
max_concurrent_requests, max_concurrent_requests,
@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
command,
} = args; } = args;
// Launch Tokio runtime let print_schema_command = match command {
init_logging(otlp_endpoint, otlp_service_name, json_output); Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};
// Validate args // Validate args
if max_input_tokens >= max_total_tokens { if max_input_tokens >= max_total_tokens {
@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
print_schema_command,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@ -1387,10 +1387,10 @@ async fn tokenize(
/// Prometheus metrics scrape endpoint /// Prometheus metrics scrape endpoint
#[utoipa::path( #[utoipa::path(
get, get,
tag = "Text Generation Inference", tag = "Text Generation Inference",
path = "/metrics", path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String)) responses((status = 200, description = "Prometheus Metrics", body = String))
)] )]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String { async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render() prom_handle.render()
@ -1430,6 +1430,7 @@ pub async fn run(
messages_api_enabled: bool, messages_api_enabled: bool,
grammar_support: bool, grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
print_schema_command: bool,
) -> Result<(), WebServerError> { ) -> Result<(), WebServerError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -1500,6 +1501,12 @@ pub async fn run(
struct ApiDoc; struct ApiDoc;
// Create state // Create state
if print_schema_command {
let api_doc = ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
}
// Open connection, get model info and warmup // Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( let (scheduler, health_ext, shard_info, max_batch_total_tokens): (

View File

@ -1,6 +1,8 @@
import subprocess import subprocess
import argparse import argparse
import ast import ast
import json
import os
TEMPLATE = """ TEMPLATE = """
# Supported Models and Hardware # Supported Models and Hardware
@ -122,6 +124,53 @@ def check_supported_models(check: bool):
f.write(final_doc) f.write(final_doc)
def get_openapi_schema():
try:
output = subprocess.check_output(["text-generation-router", "print-schema"])
return json.loads(output)
except subprocess.CalledProcessError as e:
print(f"Error running text-generation-router print-schema: {e}")
raise SystemExit(1)
except json.JSONDecodeError:
print("Error: Invalid JSON received from text-generation-router print-schema")
raise SystemExit(1)
def check_openapi(check: bool):
new_openapi_data = get_openapi_schema()
filename = "docs/openapi.json"
tmp_filename = "openapi_tmp.json"
with open(tmp_filename, "w") as f:
json.dump(new_openapi_data, f, indent=2)
if check:
diff = subprocess.run(
[
"diff",
# allow for trailing whitespace since it's not significant
# and the precommit hook will remove it
"--ignore-trailing-space",
tmp_filename,
filename,
],
capture_output=True,
).stdout.decode()
os.remove(tmp_filename)
if diff:
print(diff)
raise Exception(
"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
)
return True
else:
os.rename(tmp_filename, filename)
print("OpenAPI documentation updated.")
return True
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true") parser.add_argument("--check", action="store_true")
@ -130,6 +179,7 @@ def main():
check_cli(args.check) check_cli(args.check)
check_supported_models(args.check) check_supported_models(args.check)
check_openapi(args.check)
if __name__ == "__main__": if __name__ == "__main__":