Merge pull request #21 from lehmacdj/main

Add torch_device option to scripts/auto_caption.py
This commit is contained in:
Victor Hall 2023-03-05 18:03:36 -05:00 committed by GitHub
commit 8211f2ac90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 18 deletions

View File

@ -2,7 +2,7 @@
Automatic captioning uses Salesforce's BLIP to automatically create a clean sentence structure for captioning input images before training.
This requires an Nvidia GPU, but is not terribly intensive work. It should run fine on something like a 1050 Ti 4GB.
By default this requires an Nvidia GPU, but is not terribly intensive work. It should run fine on something like a 1050 Ti 4GB. You can even run this on the CPU by specifying `--torch_device cpu` as an argument. This will be slower than running on an Nvidia GPU, but will work even on Apple Silicon Macs.
[EveryDream trainer](https://github.com/victorchall/EveryDream-trainer) no longer requires cropped images. You only need to crop to exclude stuff you don't want trained, or to improve the portion of face close ups in your data. The EveryDream trainer now accepts multiple aspect ratios and can train on them natively.

View File

@ -66,6 +66,14 @@ def get_parser(**parser_kwargs):
default=22,
help="adjusts the likelihood of a word being repeated",
),
parser.add_argument(
"--torch_device",
type=str,
nargs="?",
const=False,
default="cuda",
help="specify a different torch device, e.g. 'cpu'",
),
return parser
@ -119,9 +127,9 @@ async def main(opt):
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
blip_decoder.eval()
print("loading model to cuda")
print(f"loading model to {opt.torch_device}")
blip_decoder = blip_decoder.to(torch.device("cuda"))
blip_decoder = blip_decoder.to(torch.device(opt.torch_device))
ext = ('.jpg', '.jpeg', '.png', '.webp', '.tif', '.tga', '.tiff', '.bmp', '.gif')
@ -141,7 +149,7 @@ async def main(opt):
if not image.mode == "RGB":
image = image.convert("RGB")
image = load_image(image, device=torch.device("cuda"))
image = load_image(image, device=torch.device(opt.torch_device))
if opt.nucleus:
captions = blip_decoder.generate(image, sample=True, top_p=opt.q_factor)
@ -207,4 +215,3 @@ if __name__ == "__main__":
sys.path.append(blip_path)
asyncio.run(main(opt))