fix warping

This commit is contained in:
OlivierDehaene 2023-05-24 19:37:42 +02:00
parent a86e4bf713
commit a794c677ae
2 changed files with 7 additions and 4 deletions

View File

@ -80,8 +80,8 @@ jobs:
latest=auto
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
ghcr.io/huggingface/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
# ghcr.io/huggingface/text-generation-inference
# db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
@ -93,7 +93,8 @@ jobs:
with:
context: .
file: Dockerfile
push: ${{ github.event_name != 'pull_request' }}
# push: ${{ github.event_name != 'pull_request' }}
push: true
platforms: 'linux/amd64'
build-args: |
GIT_SHA=${{ env.GITHUB_SHA }}

View File

@ -67,9 +67,11 @@ class StaticWarper:
self.cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cuda_graph):
local_scores = self.static_scores
for warper in self.warpers:
self.static_warped_scores = warper(None, self.static_scores)
local_scores = warper(None, local_scores)
self.static_warped_scores = local_scores
# Compute logprobs
self.static_next_logprob = torch.log_softmax(
self.static_warped_scores, -1