fix warping
This commit is contained in:
parent
a86e4bf713
commit
a794c677ae
|
@ -80,8 +80,8 @@ jobs:
|
|||
latest=auto
|
||||
images: |
|
||||
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
|
||||
ghcr.io/huggingface/text-generation-inference
|
||||
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
|
||||
# ghcr.io/huggingface/text-generation-inference
|
||||
# db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
|
||||
tags: |
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
|
@ -93,7 +93,8 @@ jobs:
|
|||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
# push: ${{ github.event_name != 'pull_request' }}
|
||||
push: true
|
||||
platforms: 'linux/amd64'
|
||||
build-args: |
|
||||
GIT_SHA=${{ env.GITHUB_SHA }}
|
||||
|
|
|
@ -67,9 +67,11 @@ class StaticWarper:
|
|||
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.graph(self.cuda_graph):
|
||||
local_scores = self.static_scores
|
||||
for warper in self.warpers:
|
||||
self.static_warped_scores = warper(None, self.static_scores)
|
||||
local_scores = warper(None, local_scores)
|
||||
|
||||
self.static_warped_scores = local_scores
|
||||
# Compute logprobs
|
||||
self.static_next_logprob = torch.log_softmax(
|
||||
self.static_warped_scores, -1
|
||||
|
|
Loading…
Reference in New Issue