feat(server): do not use device_map auto on single GPU (#362)

OlivierDehaene 2023-05-23 19:12:12 +02:00 committed by GitHub
parent cfaa858070
commit e9669a4085
2 changed files with 8 additions and 2 deletions

@@ -468,9 +468,12 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() else None,
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
             load_in_8bit=quantize == "bitsandbytes",
         )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
         tokenizer.pad_token_id = (
             model.config.pad_token_id
             if model.config.pad_token_id is not None
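
The same placement logic is applied to CausalLM above and to Seq2SeqLM below: accelerate's device_map="auto" is only used when more than one GPU is available, and on a single GPU the model is moved explicitly with .cuda(). As a minimal standalone sketch of that pattern (model_id and dtype here are placeholder values, not taken from the server configuration):

    import torch
    from transformers import AutoModelForCausalLM

    # Placeholders for illustration only.
    model_id = "gpt2"
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Hand placement to accelerate only when the model can actually be
    # sharded across more than one GPU.
    multi_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 1

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto" if multi_gpu else None,
    )

    if torch.cuda.is_available() and torch.cuda.device_count() == 1:
        # Single GPU: move the whole model explicitly instead of relying on
        # the dispatch hooks that device_map="auto" would install.
        model = model.cuda()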

@@ -518,9 +518,12 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() else None,
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
             load_in_8bit=quantize == "bitsandbytes",
         )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
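
One way to sanity-check where the weights ended up after either path (illustrative only; hf_device_map is set by transformers only when a device_map was used):

    # Illustrative placement check.
    if hasattr(model, "hf_device_map"):
        print(model.hf_device_map)  # per-module device map from accelerate
    else:
        print(model.device)  # a single device, e.g. cuda:0 after model.cuda()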