Fixing linters. (#2650)
This commit is contained in:
parent
58848cb471
commit
cf04a43fb1
|
@ -619,18 +619,11 @@ class CausalLM(Model):
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map=(
|
device_map=("auto" if device_count > 1 else None),
|
||||||
"auto"
|
|
||||||
if device_count > 1
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
if (
|
if device_count == 1 and quantize != "bitsandbytes":
|
||||||
device_count == 1
|
|
||||||
and quantize != "bitsandbytes"
|
|
||||||
):
|
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
|
|
|
@ -649,11 +649,7 @@ class Seq2SeqLM(Model):
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map=(
|
device_map=("auto" if device_count > 1 else None),
|
||||||
"auto"
|
|
||||||
if device_count > 1
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue