Device to use (e.g. cpu, cuda:0, cuda:1, etc.) (#1844)
* Device to use (e.g. cpu, cuda:0, cuda:1, etc.) * "cuda" if torch.cuda.is_available() else "cpu"
This commit is contained in:
parent
df2b548e89
commit
1f1b6c6544
|
@ -848,12 +848,17 @@ if __name__ == "__main__":
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||||
|
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
image_size = args.image_size
|
image_size = args.image_size
|
||||||
prediction_type = args.prediction_type
|
prediction_type = args.prediction_type
|
||||||
|
|
||||||
checkpoint = torch.load(args.checkpoint_path)
|
if args.device is None:
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
checkpoint = torch.load(args.checkpoint_path, map_location=device)
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(args.checkpoint_path, map_location=args.device)
|
||||||
|
|
||||||
# Sometimes models don't have the global_step item
|
# Sometimes models don't have the global_step item
|
||||||
if "global_step" in checkpoint:
|
if "global_step" in checkpoint:
|
||||||
|
|
Loading…
Reference in New Issue