feat(server): do not use device_map auto on single GPU (#362)
This commit is contained in: parent cfaa858070, commit e9669a4085.
@@ -468,9 +468,12 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() else None,
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
             load_in_8bit=quantize == "bitsandbytes",
         )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
+
         tokenizer.pad_token_id = (
             model.config.pad_token_id
             if model.config.pad_token_id is not None
@@ -518,9 +518,12 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() else None,
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
             load_in_8bit=quantize == "bitsandbytes",
         )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
+
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
Loading…
Reference in New Issue