Device to use (e.g. cpu, cuda:0, cuda:1, etc.) (#1844)
* Device to use (e.g. cpu, cuda:0, cuda:1, etc.)
* "cuda" if torch.cuda.is_available() else "cpu"
parent df2b548e89
commit 1f1b6c6544
@@ -848,12 +848,17 @@ if __name__ == "__main__":
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
     args = parser.parse_args()

     image_size = args.image_size
     prediction_type = args.prediction_type

-    checkpoint = torch.load(args.checkpoint_path)
+    if args.device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        checkpoint = torch.load(args.checkpoint_path, map_location=device)
+    else:
+        checkpoint = torch.load(args.checkpoint_path, map_location=args.device)

     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
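For reference, a minimal standalone sketch of the device-selection logic this change introduces, reduced to the two flags that matter here; resolve_map_location is a hypothetical helper name used only for illustration and is not part of the script:

import argparse

import torch


def resolve_map_location(device_arg):
    # Mirrors the behaviour added in the diff: an explicit --device value wins;
    # otherwise prefer CUDA when it is available and fall back to CPU.
    if device_arg is None:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device_arg


parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the checkpoint to convert.")
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
args = parser.parse_args()

# torch.load accepts a device string as map_location, so checkpoint tensors are
# materialised directly on the requested device.
checkpoint = torch.load(args.checkpoint_path, map_location=resolve_map_location(args.device))

Passing e.g. --device cuda:1 places the checkpoint tensors on the second GPU, and --device cpu allows conversion on machines without a GPU; with no --device, the script now loads onto "cuda" when available and "cpu" otherwise, rather than restoring tensors to whatever device locations were recorded in the checkpoint.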