delete the last no repeat processor from warpers
parent e29fc9e32a
commit 12381b0b0e
@@ -18,7 +18,6 @@ from transformers import (
     TopKLogitsWarper,
     TopPLogitsWarper,
     TypicalLogitsWarper,
-    NoRepeatNGramLogitsProcessor
 )
 
 mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -44,8 +43,6 @@ class StaticWarper:
             self.warpers.append(TopPLogitsWarper(top_p=top_p))
         if typical_p is not None and typical_p < 1.0:
             self.warpers.append(TypicalLogitsWarper(mass=typical_p))
-        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
-            self.warpers.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
 
         self.cuda_graph = None
         self.static_scores = None
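For context, here is a minimal, hypothetical sketch (not code from this repository) of how the remaining transformers warpers named in the import above can be chained over a batch of next-token logits. The sampling parameters, tensor shapes, and vocabulary size are assumptions for illustration only.

# Hypothetical illustration: chaining the warpers that remain after this change.
import torch
from transformers import (
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)

# Assumed example parameters; real values would come from the request's sampling settings.
warpers = [
    TopKLogitsWarper(top_k=50),
    TopPLogitsWarper(top_p=0.9),
    TypicalLogitsWarper(mass=0.95),
]

input_ids = torch.tensor([[1, 2, 3]])   # (batch, seq_len) previously generated token ids
scores = torch.randn(1, 32_000)         # (batch, vocab_size) next-token logits

# Each warper filters the logits produced by the previous one (masked entries become -inf).
for warper in warpers:
    scores = warper(input_ids, scores)

probs = torch.softmax(scores, dim=-1)   # ready for sampling

Unlike these warpers, which ignore input_ids and only filter the score tensor, the NoRepeatNGramLogitsProcessor removed by this commit also reads the generated token history to decide which n-grams to ban; the commit message itself states no rationale for the removal.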