diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml
index 48ed01e2..8af0b95d 100644
--- a/.github/workflows/autodocs.yaml
+++ b/.github/workflows/autodocs.yaml
@@ -11,10 +11,30 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v2

+      - name: Set up Rust
+        uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+
+      - name: Install Protocol Buffers compiler
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y protobuf-compiler libprotobuf-dev
+
       - name: Install Launcher
         id: install-launcher
         run: cargo install --path launcher/
-      - name: Check launcher Docs are up-to-date
+
+      - name: Install router
+        id: install-router
+        run: cargo install --path router/
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.x'
+
+      - name: Check that documentation is up-to-date
         run: |
-          echo text-generation-launcher --help
           python update_doc.py --check
diff --git a/docs/openapi.json b/docs/openapi.json
index 79c3b80f..7dc159a8 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "2.0.1"
+    "version": "2.1.1-dev0"
   },
   "paths": {
     "/": {
@@ -19,7 +19,6 @@
           "Text Generation Inference"
         ],
         "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
-        "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
         "operationId": "compat_generate",
         "requestBody": {
           "content": {
@@ -108,7 +107,6 @@
           "Text Generation Inference"
         ],
         "summary": "Generate tokens",
-        "description": "Generate tokens",
         "operationId": "generate",
         "requestBody": {
           "content": {
@@ -192,7 +190,6 @@
           "Text Generation Inference"
         ],
         "summary": "Generate a stream of token using Server-Sent Events",
-        "description": "Generate a stream of token using Server-Sent Events",
         "operationId": "generate_stream",
         "requestBody": {
           "content": {
@@ -276,7 +273,6 @@
           "Text Generation Inference"
         ],
         "summary": "Health check method",
-        "description": "Health check method",
         "operationId": "health",
         "responses": {
           "200": {
@@ -305,7 +301,6 @@
           "Text Generation Inference"
         ],
         "summary": "Text Generation Inference endpoint info",
-        "description": "Text Generation Inference endpoint info",
         "operationId": "get_model_info",
         "responses": {
           "200": {
@@ -327,7 +322,6 @@
           "Text Generation Inference"
         ],
         "summary": "Prometheus metrics scrape endpoint",
-        "description": "Prometheus metrics scrape endpoint",
         "operationId": "metrics",
         "responses": {
           "200": {
@@ -349,7 +343,6 @@
           "Text Generation Inference"
         ],
         "summary": "Tokenize inputs",
-        "description": "Tokenize inputs",
         "operationId": "tokenize",
         "requestBody": {
           "content": {
@@ -394,7 +387,6 @@
           "Text Generation Inference"
         ],
         "summary": "Generate tokens",
-        "description": "Generate tokens",
         "operationId": "chat_completions",
         "requestBody": {
           "content": {
@@ -483,7 +475,6 @@
           "Text Generation Inference"
         ],
         "summary": "Generate tokens",
-        "description": "Generate tokens",
         "operationId": "completions",
         "requestBody": {
           "content": {
@@ -626,7 +617,6 @@
         "type": "object",
         "required": [
           "id",
-          "object",
           "created",
           "model",
           "system_fingerprint",
@@ -653,9 +643,6 @@
             "type": "string",
             "example": "mistralai/Mistral-7B-Instruct-v0.2"
           },
-          "object": {
-            "type": "string"
-          },
           "system_fingerprint": {
             "type": "string"
           },
@@ -697,7 +684,6 @@
         "type": "object",
         "required": [
           "id",
-          "object",
           "created",
           "model",
           "system_fingerprint",
@@ -723,9 +709,6 @@
             "type": "string",
"mistralai/Mistral-7B-Instruct-v0.2" }, - "object": { - "type": "string" - }, "system_fingerprint": { "type": "string" } @@ -756,34 +739,19 @@ "nullable": true }, "message": { - "$ref": "#/components/schemas/Message" + "$ref": "#/components/schemas/OutputMessage" } } }, "ChatCompletionDelta": { - "type": "object", - "required": [ - "role" - ], - "properties": { - "content": { - "type": "string", - "example": "What is Deep Learning?", - "nullable": true + "oneOf": [ + { + "$ref": "#/components/schemas/TextMessage" }, - "role": { - "type": "string", - "example": "user" - }, - "tool_calls": { - "allOf": [ - { - "$ref": "#/components/schemas/DeltaToolCall" - } - ], - "nullable": true + { + "$ref": "#/components/schemas/ToolCallDelta" } - } + ] }, "ChatCompletionLogprob": { "type": "object", @@ -903,6 +871,15 @@ "example": 0.1, "nullable": true }, + "response_format": { + "allOf": [ + { + "$ref": "#/components/schemas/GrammarType" + } + ], + "default": "null", + "nullable": true + }, "seed": { "type": "integer", "format": "int64", @@ -1021,7 +998,6 @@ "type": "object", "required": [ "id", - "object", "created", "choices", "model", @@ -1045,9 +1021,6 @@ "model": { "type": "string" }, - "object": { - "type": "string" - }, "system_fingerprint": { "type": "string" } @@ -1081,12 +1054,7 @@ "example": "mistralai/Mistral-7B-Instruct-v0.2" }, "prompt": { - "type": "array", - "items": { - "type": "string" - }, - "description": "The prompt to generate completions for.", - "example": "What is Deep Learning?" + "$ref": "#/components/schemas/Prompt" }, "repetition_penalty": { "type": "number", @@ -1100,6 +1068,15 @@ "nullable": true, "minimum": 0 }, + "stop": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Up to 4 sequences where the API will stop generating further tokens.", + "example": "null", + "nullable": true + }, "stream": { "type": "boolean" }, @@ -1121,15 +1098,6 @@ "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", "example": 0.95, "nullable": true - }, - "stop": { - "type": "array", - "items": { - "type": "string" - }, - "description": "Up to 4 sequences where the API will stop generating further tokens.", - "example": "null", - "nullable": true } } }, @@ -1272,8 +1240,16 @@ "GenerateParameters": { "type": "object", "properties": { + "adapter_id": { + "type": "string", + "description": "Lora adapter id", + "default": "null", + "example": "null", + "nullable": true + }, "best_of": { "type": "integer", + "description": "Generate best_of sequences and return the one if the highest token logprobs.", "default": "null", "example": 1, "nullable": true, @@ -1282,20 +1258,24 @@ }, "decoder_input_details": { "type": "boolean", + "description": "Whether to return decoder input token logprobs and ids.", "default": "false" }, "details": { "type": "boolean", + "description": "Whether to return generation details.", "default": "true" }, "do_sample": { "type": "boolean", + "description": "Activate logits sampling.", "default": "false", "example": true }, "frequency_penalty": { "type": "number", "format": "float", + "description": "The parameter for frequency penalty. 
+          "description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
           "default": "null",
           "example": 0.1,
           "nullable": true,
@@ -1313,6 +1293,7 @@
         "max_new_tokens": {
           "type": "integer",
           "format": "int32",
+          "description": "Maximum number of tokens to generate.",
           "default": "100",
           "example": "20",
           "nullable": true,
@@ -1321,6 +1302,7 @@
         "repetition_penalty": {
           "type": "number",
           "format": "float",
+          "description": "The parameter for repetition penalty. 1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.",
           "default": "null",
           "example": 1.03,
           "nullable": true,
@@ -1328,6 +1310,7 @@
         },
         "return_full_text": {
           "type": "boolean",
+          "description": "Whether to prepend the prompt to the generated text",
           "default": "null",
           "example": false,
           "nullable": true
@@ -1335,6 +1318,7 @@
         },
         "seed": {
           "type": "integer",
           "format": "int64",
+          "description": "Random sampling seed.",
           "default": "null",
           "example": "null",
           "nullable": true,
@@ -1346,6 +1330,7 @@
         "stop": {
           "type": "array",
           "items": {
             "type": "string"
           },
+          "description": "Stop generating tokens if a member of `stop` is generated.",
           "example": [
             "photographer"
           ],
@@ -1354,6 +1339,7 @@
         "temperature": {
           "type": "number",
           "format": "float",
+          "description": "The value used to module the logits distribution.",
           "default": "null",
           "example": 0.5,
           "nullable": true,
@@ -1362,6 +1348,7 @@
         "top_k": {
           "type": "integer",
           "format": "int32",
+          "description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
           "default": "null",
           "example": 10,
           "nullable": true,
@@ -1370,6 +1357,7 @@
         "top_n_tokens": {
           "type": "integer",
           "format": "int32",
+          "description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.",
           "default": "null",
           "example": 5,
           "nullable": true,
@@ -1379,6 +1367,7 @@
         "top_p": {
           "type": "number",
           "format": "float",
+          "description": "Top-p value for nucleus sampling.",
           "default": "null",
           "example": 0.95,
           "nullable": true,
@@ -1387,6 +1376,7 @@
         "truncate": {
           "type": "integer",
+          "description": "Truncate inputs tokens to the given size.",
           "default": "null",
           "example": "null",
           "nullable": true,
@@ -1395,6 +1385,7 @@
         "typical_p": {
           "type": "number",
           "format": "float",
+          "description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.",
           "default": "null",
           "example": 0.95,
           "nullable": true,
@@ -1403,6 +1394,7 @@
         "watermark": {
           "type": "boolean",
+          "description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).",
           "default": "false",
           "example": true
         }
@@ -1495,13 +1487,14 @@
         "max_concurrent_requests",
         "max_best_of",
         "max_stop_sequences",
-        "max_input_length",
+        "max_input_tokens",
         "max_total_tokens",
         "waiting_served_ratio",
         "max_batch_total_tokens",
         "max_waiting_tokens",
         "validation_workers",
         "max_client_batch_size",
+        "router",
         "version"
       ],
       "properties": {
@@ -1538,7 +1531,7 @@
           "example": "128",
           "minimum": 0
         },
-        "max_input_length": {
+        "max_input_tokens": {
           "type": "integer",
           "example": "1024",
           "minimum": 0
@@ -1581,6 +1574,11 @@
           "example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
           "nullable": true
         },
+        "router": {
+          "type": "string",
+          "description": "Router Info",
+          "example": "text-generation-router"
+        },
         "sha": {
           "type": "string",
           "example": "null",
@@ -1593,7 +1591,6 @@
           "nullable": true
         },
         "version": {
           "type": "string",
-          "description": "Router Info",
           "example": "0.5.0"
         },
"waiting_served_ratio": { @@ -1606,13 +1603,12 @@ "Message": { "type": "object", "required": [ - "role" + "role", + "content" ], "properties": { "content": { - "type": "string", - "example": "My name is David and I", - "nullable": true + "$ref": "#/components/schemas/MessageContent" }, "name": { "type": "string", @@ -1622,13 +1618,6 @@ "role": { "type": "string", "example": "user" - }, - "tool_calls": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolCall" - }, - "nullable": true } } }, @@ -1817,9 +1806,7 @@ "$ref": "#/components/schemas/FunctionDefinition" }, "id": { - "type": "integer", - "format": "int32", - "minimum": 0 + "type": "string" }, "type": { "type": "string" @@ -1830,20 +1817,22 @@ "oneOf": [ { "type": "object", - "required": [ - "FunctionName" - ], - "properties": { - "FunctionName": { - "type": "string" - } - } + "default": null, + "nullable": true }, { - "type": "string", - "enum": [ - "OneOf" - ] + "type": "string" + }, + { + "type": "object", + "required": [ + "function" + ], + "properties": { + "function": { + "$ref": "#/components/schemas/FunctionName" + } + } } ] }, diff --git a/router/src/main.rs b/router/src/main.rs index 8618f57e..21cd6649 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,5 +1,6 @@ use axum::http::HeaderValue; use clap::Parser; +use clap::Subcommand; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::{Cache, Repo, RepoType}; use opentelemetry::sdk::propagation::TraceContextPropagator; @@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { + #[command(subcommand)] + command: Option, + #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, #[clap(default_value = "2", long, env)] @@ -85,10 +89,15 @@ struct Args { max_client_batch_size: usize, } +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + #[tokio::main] async fn main() -> Result<(), RouterError> { - // Get args let args = Args::parse(); + // Pattern match configuration let Args { max_concurrent_requests, @@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + command, } = args; - // Launch Tokio runtime - init_logging(otlp_endpoint, otlp_service_name, json_output); + let print_schema_command = match command { + Some(Commands::PrintSchema) => true, + None => { + // only init logging if we are not running the print schema command + init_logging(otlp_endpoint, otlp_service_name, json_output); + false + } + }; // Validate args if max_input_tokens >= max_total_tokens { @@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + print_schema_command, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index d24774f9..9be6a35c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1387,10 +1387,10 @@ async fn tokenize( /// Prometheus metrics scrape endpoint #[utoipa::path( -get, -tag = "Text Generation Inference", -path = "/metrics", -responses((status = 200, description = "Prometheus Metrics", body = String)) + get, + tag = "Text Generation Inference", + path = "/metrics", + responses((status = 200, description = "Prometheus Metrics", body = String)) )] async fn metrics(prom_handle: Extension) -> String { prom_handle.render() @@ -1430,6 +1430,7 @@ pub async fn run( messages_api_enabled: bool, grammar_support: bool, 
     max_client_batch_size: usize,
+    print_schema_command: bool,
 ) -> Result<(), WebServerError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -1500,6 +1501,12 @@ pub async fn run(
     struct ApiDoc;

     // Create state
+    if print_schema_command {
+        let api_doc = ApiDoc::openapi();
+        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
+        println!("{}", api_doc);
+        std::process::exit(0);
+    }

     // Open connection, get model info and warmup
     let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
diff --git a/update_doc.py b/update_doc.py
index 5da81c72..1ff94a2c 100644
--- a/update_doc.py
+++ b/update_doc.py
@@ -1,6 +1,8 @@
 import subprocess
 import argparse
 import ast
+import json
+import os

 TEMPLATE = """
 # Supported Models and Hardware
@@ -122,6 +124,53 @@ def check_supported_models(check: bool):
         f.write(final_doc)


+def get_openapi_schema():
+    try:
+        output = subprocess.check_output(["text-generation-router", "print-schema"])
+        return json.loads(output)
+    except subprocess.CalledProcessError as e:
+        print(f"Error running text-generation-router print-schema: {e}")
+        raise SystemExit(1)
+    except json.JSONDecodeError:
+        print("Error: Invalid JSON received from text-generation-router print-schema")
+        raise SystemExit(1)
+
+
+def check_openapi(check: bool):
+    new_openapi_data = get_openapi_schema()
+    filename = "docs/openapi.json"
+    tmp_filename = "openapi_tmp.json"
+
+    with open(tmp_filename, "w") as f:
+        json.dump(new_openapi_data, f, indent=2)
+
+    if check:
+        diff = subprocess.run(
+            [
+                "diff",
+                # allow for trailing whitespace since it's not significant
+                # and the precommit hook will remove it
+                "--ignore-trailing-space",
+                tmp_filename,
+                filename,
+            ],
+            capture_output=True,
+        ).stdout.decode()
+        os.remove(tmp_filename)
+
+        if diff:
+            print(diff)
+            raise Exception(
+                "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
+            )
+
+        return True
+    else:
+        os.rename(tmp_filename, filename)
+        print("OpenAPI documentation updated.")
+        return True
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--check", action="store_true")
@@ -130,6 +179,7 @@ def main():

     check_cli(args.check)
     check_supported_models(args.check)
+    check_openapi(args.check)


 if __name__ == "__main__":
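
To exercise the new flow locally, the following minimal Python sketch mirrors what update_doc.py now automates, assuming the router binary has been installed and is on PATH (e.g. via `cargo install --path router/`):

import json
import subprocess

# Ask the router to emit its utoipa-generated OpenAPI schema via the new
# `print-schema` subcommand, then pretty-print it into docs/openapi.json,
# mirroring what `python update_doc.py` does in update mode.
schema = json.loads(
    subprocess.check_output(["text-generation-router", "print-schema"])
)
with open("docs/openapi.json", "w") as f:
    json.dump(schema, f, indent=2)

After regenerating the file this way, `python update_doc.py --check` (the command the workflow runs) should pass, since the committed schema then matches the router's output.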