feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107)

This commit is contained in:
OlivierDehaene 2023-03-08 11:06:59 +01:00 committed by GitHub
parent b1485e18c5
commit 0ac38d336a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 190 additions and 18 deletions

View File

@ -1,7 +1,8 @@
# Text Generation # Text Generation
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub. `text-generation-inference` instance running on
[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.
## Get Started ## Get Started
@ -11,7 +12,7 @@ The Hugging Face Text Generation Python library provides a convenient way of int
pip install text-generation pip install text-generation
``` ```
### Usage ### Inference API Usage
```python ```python
from text_generation import InferenceAPIClient from text_generation import InferenceAPIClient
@ -50,3 +51,128 @@ async for response in client.generate_stream("Why is the sky blue?"):
print(text) print(text)
# ' Rayleigh scattering' # ' Rayleigh scattering'
``` ```
### Hugging Fae Inference Endpoint usage
```python
from text_generation import Client
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = Client(endpoint_url)
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import AsyncClient
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = AsyncClient(endpoint_url)
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
### Types
```python
# Prompt tokens
class PrefillToken:
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
# Optional since the logprob of the first token cannot be computed
logprob: Optional[float]
# Generated tokens
class Token:
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
logprob: float
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
# Generation finish reason
class FinishReason(Enum):
# number of generated tokens == `max_new_tokens`
Length = "length"
# the model generated its end of sequence token
EndOfSequenceToken = "eos_token"
# the model generated a text included in `stop_sequences`
StopSequence = "stop_sequence"
# `generate` details
class Details:
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# `generate` return value
class Response:
# Generated text
generated_text: str
# Generation details
details: Details
# `generate_stream` details
class StreamDetails:
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# `generate_stream` return value
class StreamResponse:
# Generated token
token: Token
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
# Generation details
# Only available when the generation is finished
details: Optional[StreamDetails]
```

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation" name = "text-generation"
version = "0.1.0" version = "0.2.0"
description = "Hugging Face Text Generation Python Client" description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__version__ = "0.3.2" __version__ = "0.2.0"
from text_generation.client import Client, AsyncClient from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient

View File

@ -63,7 +63,7 @@ class Client:
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
watermarking: bool = False, watermark: bool = False,
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text Given a prompt, generate the following text
@ -91,7 +91,7 @@ class Client:
top_p (`float`): top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation. higher are kept for generation.
watermarking (`bool`): watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns: Returns:
@ -109,7 +109,7 @@ class Client:
temperature=temperature, temperature=temperature,
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
watermark=watermarking, watermark=watermark,
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -136,7 +136,7 @@ class Client:
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
watermarking: bool = False, watermark: bool = False,
) -> Iterator[StreamResponse]: ) -> Iterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens Given a prompt, generate the following stream of tokens
@ -164,7 +164,7 @@ class Client:
top_p (`float`): top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation. higher are kept for generation.
watermarking (`bool`): watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns: Returns:
@ -182,7 +182,7 @@ class Client:
temperature=temperature, temperature=temperature,
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
watermark=watermarking, watermark=watermark,
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)
@ -268,7 +268,7 @@ class AsyncClient:
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
watermarking: bool = False, watermark: bool = False,
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text asynchronously Given a prompt, generate the following text asynchronously
@ -296,7 +296,7 @@ class AsyncClient:
top_p (`float`): top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation. higher are kept for generation.
watermarking (`bool`): watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns: Returns:
@ -314,7 +314,7 @@ class AsyncClient:
temperature=temperature, temperature=temperature,
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
watermark=watermarking, watermark=watermark,
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -338,7 +338,7 @@ class AsyncClient:
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
watermarking: bool = False, watermark: bool = False,
) -> AsyncIterator[StreamResponse]: ) -> AsyncIterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens asynchronously Given a prompt, generate the following stream of tokens asynchronously
@ -366,7 +366,7 @@ class AsyncClient:
top_p (`float`): top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation. higher are kept for generation.
watermarking (`bool`): watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns: Returns:
@ -384,7 +384,7 @@ class AsyncClient:
temperature=temperature, temperature=temperature,
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
watermark=watermarking, watermark=watermark,
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)

View File

@ -23,8 +23,10 @@ struct Args {
model_id: String, model_id: String,
#[clap(long, env)] #[clap(long, env)]
revision: Option<String>, revision: Option<String>,
#[clap(default_value = "1", long, env)] #[clap(long, env)]
num_shard: usize, sharded: Option<bool>,
#[clap(long, env)]
num_shard: Option<usize>,
#[clap(long, env)] #[clap(long, env)]
quantize: bool, quantize: bool,
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
@ -80,6 +82,7 @@ fn main() -> ExitCode {
let Args { let Args {
model_id, model_id,
revision, revision,
sharded,
num_shard, num_shard,
quantize, quantize,
max_concurrent_requests, max_concurrent_requests,
@ -102,6 +105,49 @@ fn main() -> ExitCode {
watermark_delta, watermark_delta,
} = args; } = args;
// get the number of shards given `sharded` and `num_shard`
let num_shard = if let Some(sharded) = sharded {
// sharded is set
match sharded {
// sharded is set and true
true => {
match num_shard {
None => {
// try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES")
.expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
let n_devices = cuda_visible_devices.split(",").count();
if n_devices <= 1 {
panic!("`sharded` is true but only found {n_devices} CUDA devices");
}
tracing::info!("Sharding on {n_devices} found CUDA devices");
n_devices
}
Some(num_shard) => {
// we can't have only one shard while sharded
if num_shard <= 1 {
panic!("`sharded` is true but `num_shard` <= 1");
}
num_shard
}
}
}
// sharded is set and false
false => {
let num_shard = num_shard.unwrap_or(1);
// we can't have more than one shard while not sharded
if num_shard != 1 {
panic!("`sharded` is false but `num_shard` != 1");
}
num_shard
}
}
} else {
// default to a single shard
num_shard.unwrap_or(1)
};
// Signal handler // Signal handler
let running = Arc::new(AtomicBool::new(true)); let running = Arc::new(AtomicBool::new(true));
let r = running.clone(); let r = running.clone();