From bf3c02248966684a87f1117f7982d9089e6e1ea6 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Thu, 16 Nov 2023 12:04:39 -0500 Subject: [PATCH] update docker tags --- .github/workflows/docker-publish.yml | 4 ++++ scripts/conv_ckpt_to_safetensors.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 scripts/conv_ckpt_to_safetensors.py diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index c03f225..535109f 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -59,6 +59,10 @@ jobs: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} tags: | type=raw,value=cuda121 + type=ref,event=branch + type=ref,event=tag + type=schedule,pattern=nightly + type=sha,format=long # Build and push Docker image with Buildx # https://github.com/docker/build-push-action diff --git a/scripts/conv_ckpt_to_safetensors.py b/scripts/conv_ckpt_to_safetensors.py new file mode 100644 index 0000000..0326d0c --- /dev/null +++ b/scripts/conv_ckpt_to_safetensors.py @@ -0,0 +1,20 @@ +import os +import torch +import safetensors as st +from safetensors.torch import save_file +import argparse + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument("--ckpt", type=str, required=True, help="Path to the checkpoint file") + args = argparser.parse_args() + + print(f"Loading model {args.ckpt}") + model = torch.load(f"{args.ckpt}") + print("Model loaded.") + base_name = os.path.splitext(args.ckpt)[0] + print(f"base_name: {base_name}") + + model = model.pop("state_dict", model) + save_file(model, f"{base_name}.safetensors") + print(f"File converted to safetensors: {base_name}.safetensors") \ No newline at end of file