diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index c3477eef..e9fb96b0 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -25,10 +25,10 @@ class Sampling: def __call__(self, logits): probs = torch.nn.functional.softmax(logits, -1) + # Avoid GPU<->CPU sync done by torch multinomial # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 - q = torch.empty_like(probs).exponential_(1, generator=self.generator).div_(probs) - - return q.argmax() + q = torch.empty_like(probs).exponential_(1, generator=self.generator) + return probs.div_(q).argmax() class Greedy: