diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index ccd4c3ba..bf49d134 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -63,7 +63,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index d047ccb6..9bec1dde 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -102,7 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index f965ea88..26ebc7d7 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -73,7 +73,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index 1d087a42..c93e783b 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -24,8 +24,8 @@ from text_generation.pb import generate_pb2
 
 
 class Sampling:
-    def __init__(self, seed: Optional[int] = None):
-        self.generator = torch.Generator()
+    def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
+        self.generator = torch.Generator(device)
         if seed is not None:
             self.generator.manual_seed(seed)
         else:
@@ -50,7 +50,13 @@ class Greedy:
 
 class NextTokenChooser:
     def __init__(
-        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+        self,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        do_sample=False,
+        seed=None,
+        device="cpu",
     ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
@@ -68,7 +74,7 @@ class NextTokenChooser:
             sampling = True
 
         self.warpers = warpers
-        self.choice = Sampling(seed) if sampling else Greedy()
+        self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
         # Warp logits
@@ -80,7 +86,9 @@ class NextTokenChooser:
         return next_ids, logprobs
 
     @classmethod
-    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
+    def from_pb(
+        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+    ) -> "NextTokenChooser":
         # handle protobuf making default values 0
         seed = pb.seed if pb.HasField("seed") else None
         return NextTokenChooser(
@@ -89,6 +97,7 @@
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=seed,
+            device=str(device),
         )
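
Not part of the patch: a minimal standalone sketch of why the generator's device matters. With probabilities living on a CUDA device, torch.multinomial raises a device-mismatch RuntimeError if it is given a CPU torch.Generator, which is presumably what threading device through NextTokenChooser.from_pb, NextTokenChooser.__init__, and Sampling avoids. The snippet only illustrates that PyTorch behavior under the assumption that a CUDA device may be present; it does not reproduce the repository's Sampling implementation.

import torch

# Illustration only: fall back to CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create the generator on the same device as the logits, mirroring
# torch.Generator(device) in the patched Sampling.__init__.
generator = torch.Generator(device)
generator.manual_seed(42)

# A fake next-token distribution on that device.
probs = torch.softmax(torch.randn(1, 8, device=device), dim=-1)

# torch.multinomial requires the generator and the input tensor to share a
# device; a CPU generator with CUDA probs would raise a RuntimeError here.
next_id = torch.multinomial(probs, num_samples=1, generator=generator)
print(next_id)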