fix(server): fix seeding on gpu (#42)
This commit is contained in:
parent
4f9ac67cfa
commit
03bdf18290
|
@ -63,7 +63,7 @@ class CausalLMBatch(Batch):
|
|||
for r in pb.requests:
|
||||
inputs.append(r.inputs)
|
||||
input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
stopping_criterias.append(
|
||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||
)
|
||||
|
|
|
@ -102,7 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||
inputs.append(escape_custom_split_sequence(r.inputs))
|
||||
input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
stopping_criterias.append(
|
||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||
)
|
||||
|
|
|
@ -73,7 +73,7 @@ class Seq2SeqLMBatch(Batch):
|
|||
# Decoder sequence only contains the bos_token
|
||||
decoder_input_ids.append(tokenizer.bos_token_id)
|
||||
decoder_input_lengths.append(1)
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
|
||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||
stopping_criterias.append(
|
||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||
)
|
||||
|
|
|
@ -24,8 +24,8 @@ from text_generation.pb import generate_pb2
|
|||
|
||||
|
||||
class Sampling:
|
||||
def __init__(self, seed: Optional[int] = None):
|
||||
self.generator = torch.Generator()
|
||||
def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
|
||||
self.generator = torch.Generator(device)
|
||||
if seed is not None:
|
||||
self.generator.manual_seed(seed)
|
||||
else:
|
||||
|
@ -50,7 +50,13 @@ class Greedy:
|
|||
|
||||
class NextTokenChooser:
|
||||
def __init__(
|
||||
self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
|
||||
self,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
do_sample=False,
|
||||
seed=None,
|
||||
device="cpu",
|
||||
):
|
||||
warpers = LogitsProcessorList()
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
|
@ -68,7 +74,7 @@ class NextTokenChooser:
|
|||
sampling = True
|
||||
|
||||
self.warpers = warpers
|
||||
self.choice = Sampling(seed) if sampling else Greedy()
|
||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
# Warp logits
|
||||
|
@ -80,7 +86,9 @@ class NextTokenChooser:
|
|||
return next_ids, logprobs
|
||||
|
||||
@classmethod
|
||||
def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
|
||||
def from_pb(
|
||||
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
|
||||
) -> "NextTokenChooser":
|
||||
# handle protobuf making default values 0
|
||||
seed = pb.seed if pb.HasField("seed") else None
|
||||
return NextTokenChooser(
|
||||
|
@ -89,6 +97,7 @@ class NextTokenChooser:
|
|||
top_p=pb.top_p,
|
||||
do_sample=pb.do_sample,
|
||||
seed=seed,
|
||||
device=str(device),
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue