2023-09-27 08:01:38 -06:00
import subprocess
import argparse
2024-05-22 08:22:57 -06:00
import ast
2024-07-03 01:53:35 -06:00
import json
import os
2023-09-27 08:01:38 -06:00
2024-05-22 08:22:57 -06:00
# Markdown skeleton for docs/source/supported_models.md. check_supported_models()
# replaces the standalone SUPPORTED_MODELS line with the generated bullet list.
# Blank lines around that placeholder matter: without them Markdown would fold
# the following paragraph into the last list item.
TEMPLATE = """
# Supported Models and Hardware

Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following section lists which models are supported.

## Supported Models

SUPPORTED_MODELS

If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyway to see how well it performs, but performance isn't guaranteed for non-optimized models:

```python
# for causal LMs/text-generation models
AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")
# or, for text-to-text generation models
AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")
```

If you wish to serve a supported model that already exists in a local folder, just point to the local folder.

```bash
text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
```
"""


def check_cli(check: bool):
    """Regenerate, or verify, the launcher CLI reference page.

    Runs ``text-generation-launcher --help`` and renders its output into
    ``docs/source/basic_tutorials/launcher.md``. Text before the first option
    line becomes an untitled ``shell`` code block. Every option line starts a
    new ``## HEADER`` section holding that option's help lines. The header is
    taken from the ``<VALUE>`` placeholder, or from the long flag name when the
    option takes no value, upper-cased with ``-`` turned into ``_``.

    Args:
        check: When True, compare the rendered page with the file on disk and
            raise if they differ (nothing is written). When False, overwrite
            the file with the rendered page.

    Raises:
        subprocess.CalledProcessError: if ``text-generation-launcher --help``
            exits with a non-zero status.
        Exception: in check mode, when the page on disk is out of date.
    """
    output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
        "utf-8"
    )

    def render_section(header, block):
        # The preamble before the first option has no header of its own; never
        # emit an empty "## " heading for it.
        title = f"## {header}\n" if header else ""
        rendered_block = "\n".join(block)
        return f"{title}```shell\n{rendered_block}\n```\n"

    wrap_code_blocks_flag = "<!-- WRAP CODE BLOCKS -->"
    final_doc = f"# Text-generation-launcher arguments\n\n{wrap_code_blocks_flag}\n\n"

    header = ""
    block = []
    for line in output.split("\n"):
        # Option lines are indented either as "  -x, --long ..." or as
        # "      --long ..." (presumably the help formatter's layout for
        # short+long vs long-only flags -- confirm if the launcher changes).
        if line.startswith(("  -", "      -")):
            final_doc += render_section(header, block)
            block = []
            tokens = line.split("<")
            if len(tokens) > 1:
                # "--some-opt <SOME_VALUE>" -> "SOME_VALUE": drop the final ">".
                header = tokens[-1][:-1]
            else:
                # Value-less flag: use the text after the last "--".
                header = line.split("--")[-1]
            header = header.upper().replace("-", "_")
        block.append(line)
    final_doc += render_section(header, block)

    filename = "docs/source/basic_tutorials/launcher.md"
    if check:
        with open(filename, "r", encoding="utf-8") as f:
            doc = f.read()
        if doc != final_doc:
            # Feed the fresh render to diff on stdin ("-") instead of writing a
            # scratch file into the working directory that was never removed.
            diff = subprocess.run(
                ["diff", "-", filename],
                input=final_doc.encode("utf-8"),
                capture_output=True,
            ).stdout.decode("utf-8")
            print(diff)
            raise Exception(
                "Cli arguments Doc is not up-to-date, run `python update_doc.py` in order to update it"
            )
    else:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(final_doc)


def check_supported_models(check: bool):
    """Regenerate, or verify, ``docs/source/supported_models.md``.

    The model list comes from the top-level ``ModelType`` class in
    ``server/text_generation_server/models/__init__.py``. Only that class
    definition is extracted with ``ast`` and executed (after ``import enum``),
    so the rest of that module is never run. Each member's ``value`` mapping
    supplies ``name`` and ``url``, plus an optional truthy ``multimodal`` flag.
    The resulting bullet list replaces ``SUPPORTED_MODELS`` in ``TEMPLATE``.

    Args:
        check: When True, compare the rendered page with the file on disk and
            raise if they differ (nothing is written). When False, overwrite
            the file with the rendered page.

    Raises:
        RuntimeError: if no top-level ``ModelType`` class is found.
        Exception: in check mode, when the page on disk is out of date.
    """
    source_file = "server/text_generation_server/models/__init__.py"
    with open(source_file, "r", encoding="utf-8") as f:
        tree = ast.parse(f.read())
    enum_def = next(
        (
            node
            for node in tree.body
            if isinstance(node, ast.ClassDef) and node.name == "ModelType"
        ),
        None,
    )
    if enum_def is None:
        # Fail with an actionable message instead of a bare IndexError.
        raise RuntimeError(f"No top-level `ModelType` class found in {source_file}")

    _locals = {}
    _globals = {}
    # NOTE: exec() only runs a class definition taken from this repository's
    # own source file (fixed path above), not external input.
    exec(f"import enum\n{ast.unparse(enum_def)}", _globals, _locals)
    ModelType = _locals["ModelType"]

    list_string = ""
    for data in ModelType:
        list_string += f"- [{data.value['name']}]({data.value['url']})"
        if data.value.get("multimodal", None):
            list_string += " (Multimodal)"
        list_string += "\n"
    final_doc = TEMPLATE.replace("SUPPORTED_MODELS", list_string)

    doc_file = "docs/source/supported_models.md"
    if check:
        with open(doc_file, "r", encoding="utf-8") as f:
            doc = f.read()
        if doc != final_doc:
            # Pipe the fresh render into diff via stdin ("-") rather than leaving
            # a scratch copy behind in the working directory.
            diff = subprocess.run(
                ["diff", "-", doc_file],
                input=final_doc.encode("utf-8"),
                capture_output=True,
            ).stdout.decode("utf-8")
            print(diff)
            raise Exception(
                "Supported models is not up-to-date, run `python update_doc.py` in order to update it"
            )
    else:
        with open(doc_file, "w", encoding="utf-8") as f:
            f.write(final_doc)


def get_openapi_schema():
    """Return the router's OpenAPI schema, parsed from JSON.

    Runs ``text-generation-router print-schema`` and parses its stdout.

    Raises:
        SystemExit: with a message (printed to stderr, exit status 1) when the
            router cannot be started (e.g. not on PATH), exits with a non-zero
            status, or prints something that is not valid JSON.
    """
    try:
        output = subprocess.check_output(["text-generation-router", "print-schema"])
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError covers a missing/non-executable binary, which previously
        # escaped as a raw FileNotFoundError traceback.
        raise SystemExit(f"Error running text-generation-router print-schema: {e}") from e
    try:
        return json.loads(output)
    except json.JSONDecodeError as e:
        raise SystemExit(
            "Error: Invalid JSON received from text-generation-router print-schema"
        ) from e


def check_openapi(check: bool):
    """Regenerate, or verify, ``docs/openapi.json`` from the router's schema.

    The schema from ``get_openapi_schema()`` is written (``indent=2``) to a
    temporary ``openapi_tmp.json`` in the working directory. That file is then
    either compared against the committed file or moved over it.

    Args:
        check: When True, diff against ``docs/openapi.json`` (ignoring trailing
            whitespace) and raise if they differ. When False, replace
            ``docs/openapi.json`` with the fresh schema.

    Returns:
        True on success.

    Raises:
        Exception: in check mode, when the file is out of date or ``diff``
            could not compare the files (e.g. the committed file is missing).
    """
    new_openapi_data = get_openapi_schema()
    filename = "docs/openapi.json"
    tmp_filename = "openapi_tmp.json"

    with open(tmp_filename, "w", encoding="utf-8") as f:
        json.dump(new_openapi_data, f, indent=2)

    if not check:
        # os.replace (unlike os.rename) also overwrites an existing target on Windows.
        os.replace(tmp_filename, filename)
        print("OpenAPI documentation updated.")
        return True

    try:
        result = subprocess.run(
            [
                "diff",
                # allow for trailing whitespace since it's not significant
                # and the precommit hook will remove it
                "--ignore-trailing-space",
                tmp_filename,
                filename,
            ],
            capture_output=True,
        )
    finally:
        # Clean up even if diff itself could not be launched.
        os.remove(tmp_filename)

    # diff exits 0 when identical, 1 when different and >1 on trouble. Trouble
    # leaves stdout empty, which previously made the check pass silently.
    if result.returncode > 1:
        raise Exception(
            "OpenAPI documentation could not be compared "
            f"({result.stderr.decode().strip()}), run `python update_doc.py` in order to update it"
        )
    diff = result.stdout.decode()
    if diff:
        print(diff)
        raise Exception(
            "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
        )
    return True


def main():
    """Regenerate every generated doc page, or verify them with ``--check``.

    Runs the launcher CLI reference, supported-models and OpenAPI steps in that
    order, passing the ``--check`` flag to each. In check mode the first
    out-of-date page raises and aborts the remaining steps.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Regenerate the launcher CLI reference, the supported models page "
            "and docs/openapi.json."
        )
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="only verify the generated docs are up to date; fail instead of rewriting them",
    )
    args = parser.parse_args()
    check_cli(args.check)
    check_supported_models(args.check)
    check_openapi(args.check)


if __name__ == "__main__":
    main()