fix(server): fix seeding on gpu (#42)

This commit is contained in:
OlivierDehaene 2023-01-31 14:30:33 +01:00 committed by GitHub
parent 4f9ac67cfa
commit 03bdf18290
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 17 additions and 8 deletions

View File

@@ -63,7 +63,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )

View File

@@ -102,7 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )

View File

@@ -73,7 +73,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )

View File

@@ -24,8 +24,8 @@ from text_generation.pb import generate_pb2
 class Sampling:
-    def __init__(self, seed: Optional[int] = None):
-        self.generator = torch.Generator()
+    def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
+        self.generator = torch.Generator(device)
         if seed is not None:
             self.generator.manual_seed(seed)
         else:
@@ -50,7 +50,13 @@ class Greedy:
 class NextTokenChooser:
     def __init__(
-        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+        self,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        do_sample=False,
+        seed=None,
+        device="cpu",
     ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
@@ -68,7 +74,7 @@ class NextTokenChooser:
             sampling = True

         self.warpers = warpers
-        self.choice = Sampling(seed) if sampling else Greedy()
+        self.choice = Sampling(seed, device) if sampling else Greedy()

     def __call__(self, input_ids, scores):
         # Warp logits
@@ -80,7 +86,9 @@ class NextTokenChooser:
         return next_ids, logprobs

     @classmethod
-    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
+    def from_pb(
+        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+    ) -> "NextTokenChooser":
         # handle protobuf making default values 0
         seed = pb.seed if pb.HasField("seed") else None
         return NextTokenChooser(
@@ -89,6 +97,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=seed,
+            device=str(device),
         )