fix(server): fix multinomial implementation in Sampling
This commit is contained in:
parent
a6c18c39bb
commit
4f6d038c0b
|
@ -25,10 +25,10 @@ class Sampling:
|
||||||
|
|
||||||
def __call__(self, logits):
    """Sample a single token id from the categorical distribution over `logits`.

    Instead of `torch.multinomial` (which performs a GPU<->CPU sync, see the
    linked PyTorch source), this uses the exponential-race / Gumbel-max style
    trick: if q_i ~ Exponential(1) independently, then
    argmax_i(probs_i / q_i) is distributed as Categorical(probs).

    Args:
        logits: unnormalized scores; softmax is taken over the last dimension.
            NOTE(review): assumes a 1-D logits vector — `argmax()` with no dim
            flattens, so batched input would return a single flat index.

    Returns:
        A 0-d LongTensor holding the sampled token index.
    """
    probs = torch.nn.functional.softmax(logits, -1)
    # Avoid GPU<->CPU sync done by torch multinomial
    # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
    q = torch.empty_like(probs).exponential_(1, generator=self.generator)
    # probs / q (NOT q / probs — the pre-fix version divided the noise by the
    # probabilities, inverting the weighting); in-place div_ is safe because
    # `probs` is a fresh tensor owned by this call.
    return probs.div_(q).argmax()
|
|
||||||
|
|
||||||
|
|
||||||
class Greedy:
|
class Greedy:
|
||||||
|
|
Loading…
Reference in New Issue