chore(docker): use nvidia base image (#318)

OlivierDehaene 2023-05-12 17:32:40 +02:00 committed by GitHub
parent 76a48cd365
commit 8a8f43410d
3 changed files with 21 additions and 19 deletions


@@ -108,7 +108,7 @@ COPY server/Makefile-transformers Makefile
 RUN BUILD_EXTENSIONS="True" make build-transformers
 
 # Text Generation Inference base image
-FROM debian:bullseye-slim as base
+FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
# Conda env
ENV PATH=/opt/conda/bin:$PATH \
@@ -122,17 +122,6 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     NUM_SHARD=1 \
     PORT=80
 
-# NVIDIA env vars
-ENV NVIDIA_VISIBLE_DEVICES all
-ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
-# Required for nvidia-docker v1
-RUN /bin/bash -c echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
-    echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
-ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
-ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
-LABEL com.nvidia.volumes.needed="nvidia_driver"
-
 WORKDIR /usr/src
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
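
The removed block above existed because debian:bullseye-slim knows nothing about the NVIDIA container runtime, so the Dockerfile had to declare the driver hooks by hand. The nvidia/cuda:11.8.0-base-ubuntu22.04 image already ships with NVIDIA_VISIBLE_DEVICES=all and NVIDIA_DRIVER_CAPABILITIES=compute,utility set, which makes the hand-rolled env vars, ldconfig entries, and nvidia-docker v1 label redundant. A minimal sanity check (not part of this commit; assumes the container runs under the NVIDIA runtime and has PyTorch installed):

import os
import torch

# Set by the nvidia/cuda base image; previously set by hand in the Dockerfile.
print(os.environ.get("NVIDIA_VISIBLE_DEVICES"))      # expected: "all"
print(os.environ.get("NVIDIA_DRIVER_CAPABILITIES"))  # expected: "compute,utility"

# True only if the runtime injected the driver libraries into the container.
print(torch.cuda.is_available())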


@@ -585,8 +585,20 @@ class FlashSantacoderForCausalLM(nn.Module):
         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [
-                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
-            ]
-            torch.distributed.all_gather(
-                world_logits, logits, group=self.transformer.process_group
-            )
+            if logits.shape[0] == 1:
+                # Fast path when batch size is 1
+                world_logits = logits.new_empty(
+                    (logits.shape[1] * self.transformer.tp_world_size)
+                )
+                torch.distributed.all_gather_into_tensor(
+                    world_logits, logits.view(-1), group=self.transformer.process_group
+                )
+                world_logits = world_logits.view(1, -1)
+            else:
+                # We cannot use all_gather_into_tensor as it only supports concatenating on the first dim
+                world_logits = [
+                    torch.empty_like(logits)
+                    for _ in range(self.transformer.tp_world_size)
+                ]
+                torch.distributed.all_gather(
+                    world_logits, logits, group=self.transformer.process_group
+                )
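
When the batch holds a single row, each rank's logits shard can be written directly into one flat, pre-allocated buffer with all_gather_into_tensor, avoiding the Python list of per-rank tensors that all_gather needs. A standalone sketch of the two paths (not the repository's code; uses the gloo backend so it runs on CPU, and assumes a PyTorch version where all_gather_into_tensor is available on gloo, roughly >= 1.13):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 2

def run(rank: int) -> None:
    # gloo so the sketch runs on CPU; the model issues the same calls on CUDA shards.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=WORLD_SIZE
    )
    # Each rank holds a [1, 4] shard of the logits, filled with its rank id.
    logits = torch.full((1, 4), float(rank))

    if logits.shape[0] == 1:
        # Fast path: gather every rank's row into one flat buffer, then reshape.
        world_logits = logits.new_empty(logits.shape[1] * WORLD_SIZE)
        dist.all_gather_into_tensor(world_logits, logits.view(-1))
        world_logits = world_logits.view(1, -1)
    else:
        # General path: all_gather_into_tensor only concatenates on dim 0,
        # so gather into a list and concatenate on the last dim by hand.
        shards = [torch.empty_like(logits) for _ in range(WORLD_SIZE)]
        dist.all_gather(shards, logits)
        world_logits = torch.cat(shards, dim=1)

    if rank == 0:
        print(world_logits)  # tensor([[0., 0., 0., 0., 1., 1., 1., 1.]])
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, nprocs=WORLD_SIZE)

Launched with two processes, rank 0 prints the shards concatenated in rank order, which is exactly the last-dimension concatenation the general path reproduces manually.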


@@ -217,6 +217,7 @@ class FlashSantacoderSharded(FlashSantacoder):
             device=device,
             rank=rank,
             world_size=world_size,
+            decode_buffer=1,
         )
 
     @staticmethod