feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107)
commit 0ac38d336a
parent b1485e18c5
@@ -1,7 +1,8 @@
 # Text Generation

 The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
-`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub.
+`text-generation-inference` instance running on
+[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.

 ## Get Started

@@ -11,7 +12,7 @@ The Hugging Face Text Generation Python library provides a convenient way of int
 pip install text-generation
 ```

-### Usage
+### Inference API Usage

 ```python
 from text_generation import InferenceAPIClient
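
For context, the section renamed above to "Inference API Usage" documents the existing `InferenceAPIClient` flow, which this diff does not change. A minimal sketch of that flow, with an illustrative hosted model ID:

```python
from text_generation import InferenceAPIClient

# Model ID is illustrative; any text-generation model served by the Inference API works
client = InferenceAPIClient("bigscience/bloomz")

text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
```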
@@ -50,3 +51,128 @@ async for response in client.generate_stream("Why is the sky blue?"):
 print(text)
 # ' Rayleigh scattering'
 ```
+
+### Hugging Face Inference Endpoint usage
+
+```python
+from text_generation import Client
+
+endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
+
+client = Client(endpoint_url)
+text = client.generate("Why is the sky blue?").generated_text
+print(text)
+# ' Rayleigh scattering'
+
+# Token Streaming
+text = ""
+for response in client.generate_stream("Why is the sky blue?"):
+    if not response.token.special:
+        text += response.token.text
+
+print(text)
+# ' Rayleigh scattering'
+```
+
+or with the asynchronous client:
+
+```python
+from text_generation import AsyncClient
+
+endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
+
+client = AsyncClient(endpoint_url)
+response = await client.generate("Why is the sky blue?")
+print(response.generated_text)
+# ' Rayleigh scattering'
+
+# Token Streaming
+text = ""
+async for response in client.generate_stream("Why is the sky blue?"):
+    if not response.token.special:
+        text += response.token.text
+
+print(text)
+# ' Rayleigh scattering'
+```
+
+### Types
+
+```python
+# Prompt tokens
+class PrefillToken:
+    # Token ID from the model tokenizer
+    id: int
+    # Token text
+    text: str
+    # Logprob
+    # Optional since the logprob of the first token cannot be computed
+    logprob: Optional[float]
+
+
+# Generated tokens
+class Token:
+    # Token ID from the model tokenizer
+    id: int
+    # Token text
+    text: str
+    # Logprob
+    logprob: float
+    # Is the token a special token
+    # Can be used to ignore tokens when concatenating
+    special: bool
+
+
+# Generation finish reason
+class FinishReason(Enum):
+    # number of generated tokens == `max_new_tokens`
+    Length = "length"
+    # the model generated its end of sequence token
+    EndOfSequenceToken = "eos_token"
+    # the model generated a text included in `stop_sequences`
+    StopSequence = "stop_sequence"
+
+
+# `generate` details
+class Details:
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
+# `generate` return value
+class Response:
+    # Generated text
+    generated_text: str
+    # Generation details
+    details: Details
+
+
+# `generate_stream` details
+class StreamDetails:
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+
+
+# `generate_stream` return value
+class StreamResponse:
+    # Generated token
+    token: Token
+    # Complete generated text
+    # Only available when the generation is finished
+    generated_text: Optional[str]
+    # Generation details
+    # Only available when the generation is finished
+    details: Optional[StreamDetails]
+```
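
The `Response`, `Details`, and `StreamDetails` types added above carry per-request metadata alongside the generated text. A minimal sketch of inspecting them, reusing the placeholder endpoint URL from the README examples:

```python
from text_generation import Client

client = Client("https://YOUR_ENDPOINT.endpoints.huggingface.cloud")
response = client.generate("Why is the sky blue?")

details = response.details
print(details.finish_reason)                      # e.g. FinishReason.EndOfSequenceToken
print(details.generated_tokens)                   # number of generated tokens
print([token.text for token in details.tokens])   # generated Token objects
```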
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.1.0"
+version = "0.2.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.3.2"
+__version__ = "0.2.0"

 from text_generation.client import Client, AsyncClient
 from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
@@ -63,7 +63,7 @@ class Client:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text

@@ -91,7 +91,7 @@ class Client:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:

@@ -109,7 +109,7 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -136,7 +136,7 @@ class Client:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens

@@ -164,7 +164,7 @@ class Client:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:

@@ -182,7 +182,7 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -268,7 +268,7 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously

@@ -296,7 +296,7 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:

@@ -314,7 +314,7 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -338,7 +338,7 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously

@@ -366,7 +366,7 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:

@@ -384,7 +384,7 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
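
For callers of the Python client, the change above is a keyword rename: `watermarking=` becomes `watermark=`. A minimal sketch of a call after the rename, with a placeholder endpoint URL:

```python
from text_generation import Client

client = Client("https://YOUR_ENDPOINT.endpoints.huggingface.cloud")

response = client.generate(
    "Why is the sky blue?",
    max_new_tokens=20,
    watermark=True,  # previously `watermarking=True`
)
print(response.generated_text)
```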
@@ -23,8 +23,10 @@ struct Args {
     model_id: String,
     #[clap(long, env)]
     revision: Option<String>,
-    #[clap(default_value = "1", long, env)]
-    num_shard: usize,
+    #[clap(long, env)]
+    sharded: Option<bool>,
+    #[clap(long, env)]
+    num_shard: Option<usize>,
     #[clap(long, env)]
     quantize: bool,
     #[clap(default_value = "128", long, env)]

@@ -80,6 +82,7 @@ fn main() -> ExitCode
     let Args {
         model_id,
         revision,
+        sharded,
         num_shard,
         quantize,
         max_concurrent_requests,

@@ -102,6 +105,49 @@ fn main() -> ExitCode
         watermark_delta,
     } = args;

+    // get the number of shards given `sharded` and `num_shard`
+    let num_shard = if let Some(sharded) = sharded {
+        // sharded is set
+        match sharded {
+            // sharded is set and true
+            true => {
+                match num_shard {
+                    None => {
+                        // try to default to the number of available GPUs
+                        tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
+                        let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES")
+                            .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
+                        let n_devices = cuda_visible_devices.split(",").count();
+                        if n_devices <= 1 {
+                            panic!("`sharded` is true but only found {n_devices} CUDA devices");
+                        }
+                        tracing::info!("Sharding on {n_devices} found CUDA devices");
+                        n_devices
+                    }
+                    Some(num_shard) => {
+                        // we can't have only one shard while sharded
+                        if num_shard <= 1 {
+                            panic!("`sharded` is true but `num_shard` <= 1");
+                        }
+                        num_shard
+                    }
+                }
+            }
+            // sharded is set and false
+            false => {
+                let num_shard = num_shard.unwrap_or(1);
+                // we can't have more than one shard while not sharded
+                if num_shard != 1 {
+                    panic!("`sharded` is false but `num_shard` != 1");
+                }
+                num_shard
+            }
+        }
+    } else {
+        // default to a single shard
+        num_shard.unwrap_or(1)
+    };
+
     // Signal handler
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
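
The launcher change above reduces to a small decision table over `sharded`, `num_shard`, and `CUDA_VISIBLE_DEVICES`. The following is a minimal Python sketch of the same rules, offered as an illustration of the behavior rather than the launcher's actual code:

```python
import os
from typing import Optional


def resolve_num_shard(sharded: Optional[bool], num_shard: Optional[int]) -> int:
    """Mirror the shard-count resolution shown in the Rust diff above."""
    if sharded is None:
        # Neither flag set: default to a single shard
        return 1 if num_shard is None else num_shard
    if sharded:
        if num_shard is None:
            # Fall back to counting the devices listed in CUDA_VISIBLE_DEVICES
            # (raises if the variable is unset, like the launcher's `expect`)
            devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
            if len(devices) <= 1:
                raise ValueError(f"`sharded` is true but only found {len(devices)} CUDA devices")
            return len(devices)
        if num_shard <= 1:
            raise ValueError("`sharded` is true but `num_shard` <= 1")
        return num_shard
    # sharded explicitly false: only a single shard is allowed
    if num_shard not in (None, 1):
        raise ValueError("`sharded` is false but `num_shard` != 1")
    return 1


# Example: with CUDA_VISIBLE_DEVICES="0,1,2,3", resolve_num_shard(True, None) == 4
```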