fix(server): fix seeding on gpu (#42)
parent 4f9ac67cfa
commit 03bdf18290
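The bug: `torch.Generator()` always creates a CPU generator, while `torch.multinomial` (used for seeded sampling) requires a generator on the same device as its input tensor, so seeded requests broke once the model ran on GPU. The fix threads the model's `device` from each batch's `from_pb` constructor down into `Sampling`, where the generator is now created directly on that device. A minimal sketch of the failure and the fix (assumes a CUDA machine; the `probs` tensor is illustrative):

    import torch

    # Sampling probabilities living on the GPU
    probs = torch.softmax(torch.randn(1, 32000, device="cuda"), dim=-1)

    cpu_gen = torch.Generator()  # device defaults to CPU: the old behaviour
    cpu_gen.manual_seed(42)
    # torch.multinomial(probs, 1, generator=cpu_gen)  # RuntimeError: device mismatch

    gpu_gen = torch.Generator("cuda")  # the fix: generator on the sampling device
    gpu_gen.manual_seed(42)
    next_id = torch.multinomial(probs, num_samples=1, generator=gpu_gen)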
@@ -63,7 +63,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
@@ -102,7 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
@@ -73,7 +73,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
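All three batch builders (`CausalLMBatch`, `GalacticaCausalLMBatch`, `Seq2SeqLMBatch`) make the same one-line change: the device the batch is built for is forwarded into each request's token chooser. A condensed sketch of the shared loop (names as in the diff; `device` is the model's `torch.device`):

    next_token_choosers = []
    stopping_criterias = []
    for r in pb.requests:
        # Forward the device so the chooser can seed its generator on it
        next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
        stopping_criterias.append(
            StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
        )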
@@ -24,8 +24,8 @@ from text_generation.pb import generate_pb2


 class Sampling:
-    def __init__(self, seed: Optional[int] = None):
-        self.generator = torch.Generator()
+    def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
+        self.generator = torch.Generator(device)
         if seed is not None:
             self.generator.manual_seed(seed)
         else:
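Rebuilt as a standalone snippet, the patched `Sampling` now takes the target device. The `else` branch and the `__call__` body are not shown in this diff and are assumed here:

    from typing import Optional
    import torch

    class Sampling:
        def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
            # Generator created on the sampling device instead of the CPU default
            self.generator = torch.Generator(device)
            if seed is not None:
                self.generator.manual_seed(seed)
            else:
                self.generator.seed()  # assumed: fall back to a random seed

        def __call__(self, logits: torch.Tensor) -> torch.Tensor:
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # Valid on GPU now that generator and probs share a device
            return torch.multinomial(probs, num_samples=1, generator=self.generator)

With this, `Sampling(seed=0, device="cuda")` reproduces the same tokens across runs on GPU instead of raising a device-mismatch error.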
@@ -50,7 +50,13 @@ class Greedy:

 class NextTokenChooser:
     def __init__(
-        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+        self,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        do_sample=False,
+        seed=None,
+        device="cpu",
     ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
@@ -68,7 +74,7 @@ class NextTokenChooser:
             sampling = True

         self.warpers = warpers
-        self.choice = Sampling(seed) if sampling else Greedy()
+        self.choice = Sampling(seed, device) if sampling else Greedy()

     def __call__(self, input_ids, scores):
         # Warp logits
@@ -80,7 +86,9 @@ class NextTokenChooser:
         return next_ids, logprobs

     @classmethod
-    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
+    def from_pb(
+        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+    ) -> "NextTokenChooser":
         # handle protobuf making default values 0
         seed = pb.seed if pb.HasField("seed") else None
         return NextTokenChooser(
@@ -89,6 +97,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=seed,
+            device=str(device),
         )

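End to end, a server holding a model on GPU now builds its chooser from the request protobuf plus that device; `str(device)` converts the `torch.device` into the string form `Sampling`'s annotation expects. A hedged usage sketch, where `pb_params` stands in for a real `generate_pb2.NextTokenChooserParameters` message:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # pb_params: placeholder for a generate_pb2.NextTokenChooserParameters
    # message received over gRPC
    chooser = NextTokenChooser.from_pb(pb_params, device)

    # which, per the diff, expands to roughly:
    chooser = NextTokenChooser(
        temperature=pb_params.temperature,
        top_k=pb_params.top_k,
        top_p=pb_params.top_p,
        do_sample=pb_params.do_sample,
        seed=pb_params.seed if pb_params.HasField("seed") else None,
        device=str(device),  # normalize torch.device to the str Sampling expects
    )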