chore(docker): use nvidia base image (#318)
parent 76a48cd365
commit 8a8f43410d

Dockerfile | 13
Dockerfile

@@ -108,7 +108,7 @@ COPY server/Makefile-transformers Makefile
 RUN BUILD_EXTENSIONS="True" make build-transformers
 
 # Text Generation Inference base image
-FROM debian:bullseye-slim as base
+FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
 
 # Conda env
 ENV PATH=/opt/conda/bin:$PATH \
@@ -122,17 +122,6 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     NUM_SHARD=1 \
     PORT=80
 
-# NVIDIA env vars
-ENV NVIDIA_VISIBLE_DEVICES all
-ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
-# Required for nvidia-docker v1
-RUN /bin/bash -c echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
-    echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
-ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
-ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
-
-LABEL com.nvidia.volumes.needed="nvidia_driver"
-
 WORKDIR /usr/src
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
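Switching the runtime stage to nvidia/cuda:11.8.0-base-ubuntu22.04 makes the hand-written NVIDIA block above redundant: the official CUDA base images already declare NVIDIA_VISIBLE_DEVICES and NVIDIA_DRIVER_CAPABILITIES and set up the CUDA library and binary paths, so the Debian-based image's env vars, ldconfig entries, and nvidia-docker v1 label can simply be dropped. As a quick sanity check of the new image, a sketch like the following (a hypothetical check_cuda.py, not part of this repository) can be run inside the container, e.g. with docker run --gpus all <image> python check_cuda.py, to confirm PyTorch sees the CUDA runtime:

# check_cuda.py - hypothetical smoke test for the new base image (not in the repo).
import torch

def main() -> None:
    # torch.cuda.is_available() is True only when the NVIDIA driver is exposed
    # to the container (NVIDIA Container Toolkit / the --gpus flag).
    if not torch.cuda.is_available():
        raise SystemExit("CUDA not visible; check --gpus / NVIDIA Container Toolkit")
    print(f"CUDA runtime: {torch.version.cuda}")
    print(f"Visible GPUs: {torch.cuda.device_count()}")
    print(f"Device 0:     {torch.cuda.get_device_name(0)}")

if __name__ == "__main__":
    main()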
@@ -585,8 +585,20 @@ class FlashSantacoderForCausalLM(nn.Module):
         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [
-                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
-            ]
-            torch.distributed.all_gather(
-                world_logits, logits, group=self.transformer.process_group
-            )
+            if logits.shape[0] == 1:
+                # Fast path when batch size is 1
+                world_logits = logits.new_empty(
+                    (logits.shape[1] * self.transformer.tp_world_size)
+                )
+                torch.distributed.all_gather_into_tensor(
+                    world_logits, logits.view(-1), group=self.transformer.process_group
+                )
+                world_logits = world_logits.view(1, -1)
+            else:
+                # We cannot use all_gather_into_tensor as it only support concatenating on the first dim
+                world_logits = [
+                    torch.empty_like(logits)
+                    for _ in range(self.transformer.tp_world_size)
+                ]
+                torch.distributed.all_gather(
+                    world_logits, logits, group=self.transformer.process_group
+                )
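The new branch above only applies when logits.shape[0] == 1: with a single row per rank, the sharded logits can be gathered into one flat, pre-allocated buffer with all_gather_into_tensor and viewed back as (1, world_size * vocab_shard), instead of gathering a Python list of per-rank tensors. For larger batches the list-based all_gather is kept, because flattening would interleave batch rows from different ranks. Below is a minimal single-process sketch of that shape logic (made-up sizes; the collective is emulated with torch.cat so the snippet runs without a process group, whereas the real code uses torch.distributed across tensor-parallel ranks):

import torch

tp_world_size = 4   # number of tensor-parallel ranks (made up)
vocab_shard = 8     # vocab entries held by each rank (made up)

# Per-rank logits for a batch of 1; in the real model each rank only holds its shard.
shards = [torch.randn(1, vocab_shard) for _ in range(tp_world_size)]
logits = shards[0]

# Fast path (batch size 1): one flat pre-allocated buffer, filled rank by rank.
# all_gather_into_tensor would fill `world_logits` across ranks; torch.cat emulates that here.
world_logits = logits.new_empty(logits.shape[1] * tp_world_size)
world_logits.copy_(torch.cat([s.view(-1) for s in shards]))
world_logits = world_logits.view(1, -1)

# General path (batch size > 1): gather a list of per-rank tensors, then concatenate
# along the vocab dimension. Flattening is not safe here because it would interleave
# batch rows from different ranks.
gathered = [torch.empty_like(logits) for _ in range(tp_world_size)]
for dst, src in zip(gathered, shards):
    dst.copy_(src)
full = torch.cat(gathered, dim=1)

assert torch.equal(world_logits, full)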
@@ -217,6 +217,7 @@ class FlashSantacoderSharded(FlashSantacoder):
             device=device,
             rank=rank,
             world_size=world_size,
+            decode_buffer=1,
         )
 
     @staticmethod