fix(docker): increase shm size (#60)

OlivierDehaene 2023-02-08 17:53:33 +01:00 committed by GitHub
parent c503a639b1
commit 1ad3250b89
4 changed files with 48 additions and 16 deletions

View File

@@ -30,9 +30,7 @@ ENV LANG=C.UTF-8 \
MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \
NUM_SHARD=1 \
SAFETENSORS_FAST_GPU=1 \
PORT=80 \
NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \

View File

@@ -25,8 +25,9 @@ to power LLMs api-inference widgets.
- [Officially Supported Models](#officially-supported-models)
- [Get Started](#get-started)
- [Docker](#docker)
- [API Documentation](#api-documentation)
- [A note on Shared Memory](#a-note-on-shared-memory-shm)
- [Local Install](#local-install)
- [OpenAPI](#api-documentation)
- [CUDA Kernels](#cuda-kernels)
- [Run BLOOM](#run-bloom)
- [Download](#download)
@@ -54,7 +55,7 @@ to power LLMs api-inference widgets.
- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
- [SantaCoder](https://huggingface.co/bigcode/santacoder)
- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl): use `--revision pr/26`
- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
Other models are supported on a best effort basis using:
@@ -75,7 +76,7 @@ model=bigscience/bloom-560m
num_shard=2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
```
You can then query the model using either the `/generate` or `/generate_stream` routes:
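For example, a minimal `/generate` request looks like the following (the full set of request parameters is described in the API documentation referenced below):

```shell
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'
```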
@@ -101,6 +102,32 @@ curl 127.0.0.1:8080/generate_stream \
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
### A note on Shared Memory (shm)
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
`PyTorch` for distributed training and inference. `text-generation-inference` makes
use of `NCCL` to enable Tensor Parallelism, which dramatically speeds up inference for large language models.
In order to share data between the different devices of an `NCCL` group, `NCCL` may fall back to host memory if
peer-to-peer communication through NVLink or PCI is not possible.
To allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` to the `docker run` command above.
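You can check the effect of this flag by comparing the size of `/dev/shm` with and without it (the `ubuntu` image below is used purely for illustration):

```shell
# Default Docker shared memory is only 64M
docker run --rm ubuntu df -h /dev/shm

# With the flag, the container gets 1G of shared memory
docker run --rm --shm-size 1g ubuntu df -h /dev/shm
```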
If you are running `text-generation-inference` inside `Kubernetes`, you can also add Shared Memory to the container by
creating a volume with:
```yaml
- name: shm
  emptyDir:
    medium: Memory
    sizeLimit: 1Gi
```
and mounting it to `/dev/shm`.
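Once the pod is running, you can verify that the mount is in place (`<pod-name>` is a placeholder for your pod):

```shell
kubectl exec <pod-name> -- df -h /dev/shm
```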
Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that
this will impact performance.
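With Docker, this can be done by passing the variable to the container, reusing the variables from the command above:

```shell
docker run --gpus all -p 8080:80 -v $volume:/data -e NCCL_SHM_DISABLE=1 ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
```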
### Local install
You can also opt to install `text-generation-inference` locally.
@@ -122,10 +149,10 @@ BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork
make run-bloom-560m
```
**Note:** on some machines, you may also need the OpenSSL libraries. On Linux machines, run:
**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
```shell
sudo apt-get install libssl-dev
sudo apt-get install libssl-dev gcc -y
```
### CUDA Kernels

View File

@@ -38,9 +38,9 @@ struct Args {
port: u16,
#[clap(default_value = "/tmp/text-generation-server", long, env)]
shard_uds_path: String,
#[clap(default_value = "0.0.0.0", long, env)]
#[clap(default_value = "localhost", long, env)]
master_addr: String,
#[clap(default_value = "6000", long, env)]
#[clap(default_value = "29500", long, env)]
master_port: usize,
#[clap(long, env)]
json_output: bool,
@@ -305,6 +305,7 @@ fn shard_manager(
("MASTER_ADDR".into(), master_addr.into()),
("MASTER_PORT".into(), master_port.to_string().into()),
("SAFETENSORS_FAST_GPU".into(), "1".into()),
("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
@@ -322,6 +323,12 @@ fn shard_manager(
));
};
// If the NCCL_SHM_DISABLE env var is set, pass it to the shard
// needed when running NCCL inside a docker container and when you can't increase shm size
if let Ok(nccl_shm_disable) = env::var("NCCL_SHM_DISABLE") {
env.push(("NCCL_SHM_DISABLE".into(), nccl_shm_disable.into()));
};
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));

View File

@@ -162,29 +162,29 @@ def initialize_torch_distributed():
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
# initialized `torch.distributed`
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
master_ip = os.getenv("MASTER_ADDR", "0.0.0.0")
master_port = os.getenv("MASTER_PORT", "6000")
init_method = f"tcp://{master_ip}:{master_port}"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
init_method=init_method,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options
)
return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
return torch.distributed.group.WORLD, rank, world_size
def weight_hub_files(model_id, revision=None, extension=".safetensors"):