update docker tags
This commit is contained in:
parent
840493037e
commit
bf3c022489
|
@ -59,6 +59,10 @@ jobs:
|
|||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=raw,value=cuda121
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
type=schedule,pattern=nightly
|
||||
type=sha,format=long
|
||||
|
||||
# Build and push Docker image with Buildx
|
||||
# https://github.com/docker/build-push-action
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
import os
|
||||
import torch
|
||||
import safetensors as st
|
||||
from safetensors.torch import save_file
|
||||
import argparse
|
||||
|
||||
if __name__ == "__main__":
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument("--ckpt", type=str, required=True, help="Path to the checkpoint file")
|
||||
args = argparser.parse_args()
|
||||
|
||||
print(f"Loading model {args.ckpt}")
|
||||
model = torch.load(f"{args.ckpt}")
|
||||
print("Model loaded.")
|
||||
base_name = os.path.splitext(args.ckpt)[0]
|
||||
print(f"base_name: {base_name}")
|
||||
|
||||
model = model.pop("state_dict", model)
|
||||
save_file(model, f"{base_name}.safetensors")
|
||||
print(f"File converted to safetensors: {base_name}.safetensors")
|
Loading…
Reference in New Issue