From 0ac38d336a870bd4f09e18a9b62d64d78b032fbb Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 8 Mar 2023 11:06:59 +0100 Subject: [PATCH] feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107) --- clients/python/README.md | 130 ++++++++++++++++++++- clients/python/pyproject.toml | 2 +- clients/python/text_generation/__init__.py | 2 +- clients/python/text_generation/client.py | 24 ++-- launcher/src/main.rs | 50 +++++++- 5 files changed, 190 insertions(+), 18 deletions(-) diff --git a/clients/python/README.md b/clients/python/README.md index 414360b..0f0b32f 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -1,7 +1,8 @@ # Text Generation The Hugging Face Text Generation Python library provides a convenient way of interfacing with a -`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub. +`text-generation-inference` instance running on +[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub. ## Get Started @@ -11,7 +12,7 @@ The Hugging Face Text Generation Python library provides a convenient way of int pip install text-generation ``` -### Usage +### Inference API Usage ```python from text_generation import InferenceAPIClient @@ -50,3 +51,128 @@ async for response in client.generate_stream("Why is the sky blue?"): print(text) # ' Rayleigh scattering' ``` + +### Hugging Fae Inference Endpoint usage + +```python +from text_generation import Client + +endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud" + +client = Client(endpoint_url) +text = client.generate("Why is the sky blue?").generated_text +print(text) +# ' Rayleigh scattering' + +# Token Streaming +text = "" +for response in client.generate_stream("Why is the sky blue?"): + if not response.token.special: + text += response.token.text + +print(text) +# ' Rayleigh scattering' +``` + +or with the asynchronous client: + +```python +from text_generation import AsyncClient + +endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud" + +client = AsyncClient(endpoint_url) +response = await client.generate("Why is the sky blue?") +print(response.generated_text) +# ' Rayleigh scattering' + +# Token Streaming +text = "" +async for response in client.generate_stream("Why is the sky blue?"): + if not response.token.special: + text += response.token.text + +print(text) +# ' Rayleigh scattering' +``` + +### Types + +```python +# Prompt tokens +class PrefillToken: + # Token ID from the model tokenizer + id: int + # Token text + text: str + # Logprob + # Optional since the logprob of the first token cannot be computed + logprob: Optional[float] + + +# Generated tokens +class Token: + # Token ID from the model tokenizer + id: int + # Token text + text: str + # Logprob + logprob: float + # Is the token a special token + # Can be used to ignore tokens when concatenating + special: bool + + +# Generation finish reason +class FinishReason(Enum): + # number of generated tokens == `max_new_tokens` + Length = "length" + # the model generated its end of sequence token + EndOfSequenceToken = "eos_token" + # the model generated a text included in `stop_sequences` + StopSequence = "stop_sequence" + + +# `generate` details +class Details: + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] + # Prompt tokens + prefill: List[PrefillToken] + # Generated tokens + tokens: List[Token] + + +# `generate` return value +class Response: + # Generated text + generated_text: str + # Generation details + details: Details + + +# `generate_stream` details +class StreamDetails: + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] + + +# `generate_stream` return value +class StreamResponse: + # Generated token + token: Token + # Complete generated text + # Only available when the generation is finished + generated_text: Optional[str] + # Generation details + # Only available when the generation is finished + details: Optional[StreamDetails] +``` \ No newline at end of file diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 6eb1163..4c7b24e 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation" -version = "0.1.0" +version = "0.2.0" description = "Hugging Face Text Generation Python Client" license = "Apache-2.0" authors = ["Olivier Dehaene "] diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py index 88861b3..1f5e6ce 100644 --- a/clients/python/text_generation/__init__.py +++ b/clients/python/text_generation/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.3.2" +__version__ = "0.2.0" from text_generation.client import Client, AsyncClient from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index d0a6791..3e9bbc3 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -63,7 +63,7 @@ class Client: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - watermarking: bool = False, + watermark: bool = False, ) -> Response: """ Given a prompt, generate the following text @@ -91,7 +91,7 @@ class Client: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. - watermarking (`bool`): + watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Returns: @@ -109,7 +109,7 @@ class Client: temperature=temperature, top_k=top_k, top_p=top_p, - watermark=watermarking, + watermark=watermark, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -136,7 +136,7 @@ class Client: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - watermarking: bool = False, + watermark: bool = False, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens @@ -164,7 +164,7 @@ class Client: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. - watermarking (`bool`): + watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Returns: @@ -182,7 +182,7 @@ class Client: temperature=temperature, top_k=top_k, top_p=top_p, - watermark=watermarking, + watermark=watermark, ) request = Request(inputs=prompt, stream=True, parameters=parameters) @@ -268,7 +268,7 @@ class AsyncClient: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - watermarking: bool = False, + watermark: bool = False, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -296,7 +296,7 @@ class AsyncClient: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. - watermarking (`bool`): + watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Returns: @@ -314,7 +314,7 @@ class AsyncClient: temperature=temperature, top_k=top_k, top_p=top_p, - watermark=watermarking, + watermark=watermark, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -338,7 +338,7 @@ class AsyncClient: temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, - watermarking: bool = False, + watermark: bool = False, ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens asynchronously @@ -366,7 +366,7 @@ class AsyncClient: top_p (`float`): If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. - watermarking (`bool`): + watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Returns: @@ -384,7 +384,7 @@ class AsyncClient: temperature=temperature, top_k=top_k, top_p=top_p, - watermark=watermarking, + watermark=watermark, ) request = Request(inputs=prompt, stream=True, parameters=parameters) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 66dcb2d..96ad18f 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -23,8 +23,10 @@ struct Args { model_id: String, #[clap(long, env)] revision: Option, - #[clap(default_value = "1", long, env)] - num_shard: usize, + #[clap(long, env)] + sharded: Option, + #[clap(long, env)] + num_shard: Option, #[clap(long, env)] quantize: bool, #[clap(default_value = "128", long, env)] @@ -80,6 +82,7 @@ fn main() -> ExitCode { let Args { model_id, revision, + sharded, num_shard, quantize, max_concurrent_requests, @@ -102,6 +105,49 @@ fn main() -> ExitCode { watermark_delta, } = args; + // get the number of shards given `sharded` and `num_shard` + let num_shard = if let Some(sharded) = sharded { + // sharded is set + match sharded { + // sharded is set and true + true => { + match num_shard { + None => { + // try to default to the number of available GPUs + tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); + let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES") + .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); + let n_devices = cuda_visible_devices.split(",").count(); + if n_devices <= 1 { + panic!("`sharded` is true but only found {n_devices} CUDA devices"); + } + tracing::info!("Sharding on {n_devices} found CUDA devices"); + n_devices + } + Some(num_shard) => { + // we can't have only one shard while sharded + if num_shard <= 1 { + panic!("`sharded` is true but `num_shard` <= 1"); + } + num_shard + } + } + } + // sharded is set and false + false => { + let num_shard = num_shard.unwrap_or(1); + // we can't have more than one shard while not sharded + if num_shard != 1 { + panic!("`sharded` is false but `num_shard` != 1"); + } + num_shard + } + } + } else { + // default to a single shard + num_shard.unwrap_or(1) + }; + // Signal handler let running = Arc::new(AtomicBool::new(true)); let r = running.clone();