commit 53aa9194c8
parent ece7ffa40a
@@ -237,20 +237,12 @@ def get_model(
         )
 
     elif model_type == "t5":
-        if sharded:
-            return T5Sharded(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
-        else:
-            return Seq2SeqLM(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
+        return T5Sharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
+        )
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
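Review note: `T5Sharded` now serves both the sharded and unsharded cases for t5 checkpoints, so the `Seq2SeqLM` fallback is removed and the only remaining `sharded` check guards the AutoModel path. Below is a minimal sketch of the resulting dispatch, assuming the `get_model` parameters visible in the hunk (`model_id`, `revision`, `sharded`, `quantize`, `trust_remote_code`); the checkpoint name is a placeholder, not taken from this commit.

# Illustrative only: exercises the new t5 path in get_model.
# "t5-small" is a placeholder checkpoint, not part of this commit.
from text_generation_server.models import get_model

model = get_model(
    model_id="t5-small",
    revision=None,
    sharded=False,  # True or False: both now resolve to T5Sharded
    quantize=None,
    trust_remote_code=False,
)
assert type(model).__name__ == "T5Sharded"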
@@ -42,25 +42,31 @@ class StaticWarper:
         self.static_next_logprob = None
 
     def __call__(self, scores):
-        if self.cuda_graph is None:
-            self.static_scores = scores
-            self.cuda_graph = torch.cuda.CUDAGraph()
+        if torch.cuda.is_available():
+            if self.cuda_graph is None:
+                self.static_scores = scores
+                self.cuda_graph = torch.cuda.CUDAGraph()
 
-            with torch.cuda.graph(self.cuda_graph, pool=mempool):
-                local_scores = self.static_scores
-                for warper in self.warpers:
-                    local_scores = warper(None, local_scores)
+                with torch.cuda.graph(self.cuda_graph, pool=mempool):
+                    local_scores = self.static_scores
+                    for warper in self.warpers:
+                        local_scores = warper(None, local_scores)
 
-                self.static_warped_scores = local_scores
-                # Compute logprobs
-                self.static_next_logprob = torch.log_softmax(
-                    self.static_warped_scores, -1
-                )
+                    self.static_warped_scores = local_scores
+                    # Compute logprobs
+                    self.static_next_logprob = torch.log_softmax(
+                        self.static_warped_scores, -1
+                    )
 
-        self.static_scores.copy_(scores)
-        self.cuda_graph.replay()
+            self.static_scores.copy_(scores)
+            self.cuda_graph.replay()
 
-        return self.static_warped_scores, self.static_next_logprob
+            return self.static_warped_scores, self.static_next_logprob
+
+        # CPU branch
+        for warper in self.warpers:
+            scores = warper(None, scores)
+        return scores, torch.log_softmax(scores, -1)
 
 
 @lru_cache(10)
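Review note: the CUDA-graph capture/replay path is unchanged in behavior, only re-indented under the `torch.cuda.is_available()` guard; new inputs are still copied into `static_scores` and the recorded graph replayed. The sketch below is a standalone approximation of the new CPU branch, assuming warpers follow the Hugging Face `LogitsWarper` calling convention `warper(input_ids, scores)`; the `warp_on_cpu` name and the `TemperatureLogitsWarper` example are illustrative, not part of the diff.

import torch
from transformers import TemperatureLogitsWarper

def warp_on_cpu(warpers, scores):
    # Mirrors the new fallback: apply each warper eagerly (no CUDA graph),
    # then return the warped scores together with their log-softmax.
    for warper in warpers:
        scores = warper(None, scores)
    return scores, torch.log_softmax(scores, -1)

scores = torch.randn(1, 32000)  # fake next-token logits
warped, logprobs = warp_on_cpu([TemperatureLogitsWarper(0.7)], scores)

On GPU the eager loop would re-launch every warper kernel at each decoding step, which is exactly what the captured graph avoids; on CPU there is nothing to capture, so eager application is the natural fallback.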