chore(docker): use nvidia base image (#318)
This commit is contained in:
parent
76a48cd365
commit
8a8f43410d
13
Dockerfile
13
Dockerfile
|
@ -108,7 +108,7 @@ COPY server/Makefile-transformers Makefile
|
|||
RUN BUILD_EXTENSIONS="True" make build-transformers
|
||||
|
||||
# Text Generation Inference base image
|
||||
FROM debian:bullseye-slim as base
|
||||
FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
|
||||
|
||||
# Conda env
|
||||
ENV PATH=/opt/conda/bin:$PATH \
|
||||
|
@ -122,17 +122,6 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
|
|||
NUM_SHARD=1 \
|
||||
PORT=80
|
||||
|
||||
# NVIDIA env vars
|
||||
ENV NVIDIA_VISIBLE_DEVICES all
|
||||
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
|
||||
# Required for nvidia-docker v1
|
||||
RUN /bin/bash -c echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
|
||||
echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
|
||||
ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
|
||||
|
||||
LABEL com.nvidia.volumes.needed="nvidia_driver"
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
|
|
|
@ -585,13 +585,25 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
|
||||
if self.transformer.tp_embeddings:
|
||||
# Logits are sharded, so we need to gather them
|
||||
world_logits = [
|
||||
torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
|
||||
]
|
||||
torch.distributed.all_gather(
|
||||
world_logits, logits, group=self.transformer.process_group
|
||||
)
|
||||
world_logits = torch.cat(world_logits, dim=1)
|
||||
if logits.shape[0] == 1:
|
||||
# Fast path when batch size is 1
|
||||
world_logits = logits.new_empty(
|
||||
(logits.shape[1] * self.transformer.tp_world_size)
|
||||
)
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
world_logits, logits.view(-1), group=self.transformer.process_group
|
||||
)
|
||||
world_logits = world_logits.view(1, -1)
|
||||
else:
|
||||
# We cannot use all_gather_into_tensor as it only support concatenating on the first dim
|
||||
world_logits = [
|
||||
torch.empty_like(logits)
|
||||
for _ in range(self.transformer.tp_world_size)
|
||||
]
|
||||
torch.distributed.all_gather(
|
||||
world_logits, logits, group=self.transformer.process_group
|
||||
)
|
||||
world_logits = torch.cat(world_logits, dim=1)
|
||||
|
||||
return world_logits, present
|
||||
|
||||
|
|
|
@ -217,6 +217,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
decode_buffer=1,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue