feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107)
parent b1485e18c5
commit 0ac38d336a
@@ -1,7 +1,8 @@
 # Text Generation

 The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
-`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub.
+`text-generation-inference` instance running on
+[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.

 ## Get Started

@@ -11,7 +12,7 @@ The Hugging Face Text Generation Python library provides a convenient way of int
 pip install text-generation
 ```

-### Usage
+### Inference API Usage

 ```python
 from text_generation import InferenceAPIClient
@@ -50,3 +51,128 @@ async for response in client.generate_stream("Why is the sky blue?"):
 print(text)
 # ' Rayleigh scattering'
 ```
+
+### Hugging Face Inference Endpoint usage
+
+```python
+from text_generation import Client
+
+endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
+
+client = Client(endpoint_url)
+text = client.generate("Why is the sky blue?").generated_text
+print(text)
+# ' Rayleigh scattering'
+
+# Token Streaming
+text = ""
+for response in client.generate_stream("Why is the sky blue?"):
+    if not response.token.special:
+        text += response.token.text
+
+print(text)
+# ' Rayleigh scattering'
+```
+
+or with the asynchronous client:
+
+```python
+from text_generation import AsyncClient
+
+endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
+
+client = AsyncClient(endpoint_url)
+response = await client.generate("Why is the sky blue?")
+print(response.generated_text)
+# ' Rayleigh scattering'
+
+# Token Streaming
+text = ""
+async for response in client.generate_stream("Why is the sky blue?"):
+    if not response.token.special:
+        text += response.token.text
+
+print(text)
+# ' Rayleigh scattering'
+```
+
+### Types
+
+```python
+# Prompt tokens
+class PrefillToken:
+    # Token ID from the model tokenizer
+    id: int
+    # Token text
+    text: str
+    # Logprob
+    # Optional since the logprob of the first token cannot be computed
+    logprob: Optional[float]
+
+
+# Generated tokens
+class Token:
+    # Token ID from the model tokenizer
+    id: int
+    # Token text
+    text: str
+    # Logprob
+    logprob: float
+    # Is the token a special token
+    # Can be used to ignore tokens when concatenating
+    special: bool
+
+
+# Generation finish reason
+class FinishReason(Enum):
+    # number of generated tokens == `max_new_tokens`
+    Length = "length"
+    # the model generated its end of sequence token
+    EndOfSequenceToken = "eos_token"
+    # the model generated a text included in `stop_sequences`
+    StopSequence = "stop_sequence"
+
+
+# `generate` details
+class Details:
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
+# `generate` return value
+class Response:
+    # Generated text
+    generated_text: str
+    # Generation details
+    details: Details
+
+
+# `generate_stream` details
+class StreamDetails:
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+
+
+# `generate_stream` return value
+class StreamResponse:
+    # Generated token
+    token: Token
+    # Complete generated text
+    # Only available when the generation is finished
+    generated_text: Optional[str]
+    # Generation details
+    # Only available when the generation is finished
+    details: Optional[StreamDetails]
+```
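As a quick illustration of the `Types` block added above, a call to `generate` returns a `Response` whose `details` field carries the documented metadata. A minimal sketch, reusing the placeholder endpoint URL from the README examples (actual finish reasons and token values depend on the request):

```python
from text_generation import Client

# Placeholder endpoint, as in the README examples above
client = Client("https://YOUR_ENDPOINT.endpoints.huggingface.cloud")
response = client.generate("Why is the sky blue?")

# Fields documented in the Types section above
print(response.generated_text)
print(response.details.finish_reason)     # FinishReason.Length, EndOfSequenceToken, or StopSequence
print(response.details.generated_tokens)  # number of generated tokens
print(response.details.seed)              # None unless sampling was activated
for token in response.details.tokens:     # one entry per generated token
    print(token.id, token.text, token.logprob, token.special)
```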
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.1.0"
+version = "0.2.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.3.2"
+__version__ = "0.2.0"

 from text_generation.client import Client, AsyncClient
 from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
@@ -63,7 +63,7 @@ class Client:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -91,7 +91,7 @@ class Client:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:
@@ -109,7 +109,7 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

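The rename from `watermarking` to `watermark` also changes the keyword that callers pass. A minimal sketch against the synchronous client shown above (the endpoint URL is the same placeholder used in the README):

```python
from text_generation import Client

# Placeholder endpoint, as in the README examples
client = Client("https://YOUR_ENDPOINT.endpoints.huggingface.cloud")

# After this change the keyword is `watermark` (previously `watermarking`)
response = client.generate(
    "Why is the sky blue?",
    watermark=True,  # watermarking per https://arxiv.org/abs/2301.10226
)
print(response.generated_text)
```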
@@ -136,7 +136,7 @@ class Client:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -164,7 +164,7 @@ class Client:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:
@@ -182,7 +182,7 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -268,7 +268,7 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -296,7 +296,7 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:
@@ -314,7 +314,7 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -338,7 +338,7 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -366,7 +366,7 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

         Returns:
@@ -384,7 +384,7 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -23,8 +23,10 @@ struct Args {
     model_id: String,
     #[clap(long, env)]
     revision: Option<String>,
-    #[clap(default_value = "1", long, env)]
-    num_shard: usize,
+    #[clap(long, env)]
+    sharded: Option<bool>,
+    #[clap(long, env)]
+    num_shard: Option<usize>,
     #[clap(long, env)]
     quantize: bool,
     #[clap(default_value = "128", long, env)]
@@ -80,6 +82,7 @@ fn main() -> ExitCode {
     let Args {
         model_id,
         revision,
+        sharded,
         num_shard,
         quantize,
         max_concurrent_requests,
@@ -102,6 +105,49 @@ fn main() -> ExitCode {
         watermark_delta,
     } = args;

+    // get the number of shards given `sharded` and `num_shard`
+    let num_shard = if let Some(sharded) = sharded {
+        // sharded is set
+        match sharded {
+            // sharded is set and true
+            true => {
+                match num_shard {
+                    None => {
+                        // try to default to the number of available GPUs
+                        tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
+                        let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES")
+                            .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
+                        let n_devices = cuda_visible_devices.split(",").count();
+                        if n_devices <= 1 {
+                            panic!("`sharded` is true but only found {n_devices} CUDA devices");
+                        }
+                        tracing::info!("Sharding on {n_devices} found CUDA devices");
+                        n_devices
+                    }
+                    Some(num_shard) => {
+                        // we can't have only one shard while sharded
+                        if num_shard <= 1 {
+                            panic!("`sharded` is true but `num_shard` <= 1");
+                        }
+                        num_shard
+                    }
+                }
+            }
+            // sharded is set and false
+            false => {
+                let num_shard = num_shard.unwrap_or(1);
+                // we can't have more than one shard while not sharded
+                if num_shard != 1 {
+                    panic!("`sharded` is false but `num_shard` != 1");
+                }
+                num_shard
+            }
+        }
+    } else {
+        // default to a single shard
+        num_shard.unwrap_or(1)
+    };
+
     // Signal handler
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
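To summarise the new behaviour: the launcher now resolves the shard count from the `sharded` and `num_shard` options and, when neither forces a count, from the `CUDA_VISIBLE_DEVICES` environment variable. The sketch below mirrors the Rust logic above in Python purely for illustration; the function name and exception types are invented here, while the resolution rules themselves come from the change:

```python
import os
from typing import Optional


def resolve_num_shard(sharded: Optional[bool], num_shard: Optional[int]) -> int:
    """Illustrative Python mirror of the launcher's new num_shard resolution
    (not the actual implementation, which is the Rust code above)."""
    if sharded is None:
        # neither forced on nor off: default to a single shard
        return num_shard if num_shard is not None else 1
    if sharded:
        if num_shard is None:
            # fall back to counting devices in CUDA_VISIBLE_DEVICES
            devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
            if len(devices) <= 1:
                raise RuntimeError(f"`sharded` is true but only found {len(devices)} CUDA devices")
            return len(devices)
        if num_shard <= 1:
            raise ValueError("`sharded` is true but `num_shard` <= 1")
        return num_shard
    # sharded is explicitly false: more than one shard is rejected
    if num_shard is not None and num_shard != 1:
        raise ValueError("`sharded` is false but `num_shard` != 1")
    return 1


# e.g. CUDA_VISIBLE_DEVICES="0,1,2,3", sharded=True, num_shard=None -> 4
```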