diff --git a/.gitignore b/.gitignore
index 6dafcb9..d00aed5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,8 @@
 
 # Programming - general
 *.log
+example.png
+scores.json
 
 # ===========================================================================
 #
diff --git a/aesthetics/aesthetics.py b/aesthetics/aesthetics.py
new file mode 100644
index 0000000..5b85840
--- /dev/null
+++ b/aesthetics/aesthetics.py
@@ -0,0 +1,146 @@
+import webdataset as wds
+from PIL import Image
+import io
+import matplotlib.pyplot as plt
+import os
+import json
+
+from warnings import filterwarnings
+
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # choose GPU if you are on a multi-GPU server
+import numpy as np
+import torch
+import pytorch_lightning as pl
+import torch.nn as nn
+import torch.nn.functional as F  # needed by F.mse_loss in the Lightning training/validation steps
+from torchvision import datasets, transforms
+import tqdm
+
+from os.path import join
+from datasets import load_dataset
+import pandas as pd
+from torch.utils.data import Dataset, DataLoader
+import json
+
+import clip
+
+
+from PIL import Image, ImageFile
+
+
+##### This script will predict the aesthetic score for this image file:
+
+img_path = "../250k_data-0/img/000baa665498e7a61130d7662f81e698.jpg"
+
+
+
+
+
+# if you changed the MLP architecture during training, change it also here:
+class MLP(pl.LightningModule):
+    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
+        super().__init__()
+        self.input_size = input_size
+        self.xcol = xcol
+        self.ycol = ycol
+        self.layers = nn.Sequential(
+            nn.Linear(self.input_size, 1024),
+            #nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            #nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            #nn.ReLU(),
+            nn.Dropout(0.1),
+
+            nn.Linear(64, 16),
+            #nn.ReLU(),
+
+            nn.Linear(16, 1)
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+    def training_step(self, batch, batch_idx):
+        x = batch[self.xcol]
+        y = batch[self.ycol].reshape(-1, 1)
+        x_hat = self.layers(x)
+        loss = F.mse_loss(x_hat, y)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x = batch[self.xcol]
+        y = batch[self.ycol].reshape(-1, 1)
+        x_hat = self.layers(x)
+        loss = F.mse_loss(x_hat, y)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+        return optimizer
+
+def normalized(a, axis=-1, order=2):
+    import numpy as np  # pylint: disable=import-outside-toplevel
+
+    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
+    l2[l2 == 0] = 1
+    return a / np.expand_dims(l2, axis)
+
+
+model = MLP(768)  # CLIP embedding dim is 768 for CLIP ViT-L/14
+
+s = torch.load("sac+logos+ava1-l14-linearMSE.pth")  # load the model you trained previously or the model available in this repo
+
+model.load_state_dict(s)
+
+model.to("cuda")
+model.eval()
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model2, preprocess = clip.load("ViT-L/14", device=device)  # RN50x64
+
+@torch.inference_mode()
+def aesthetic(img_path):
+    pil_image = Image.open(img_path)
+    image = preprocess(pil_image).unsqueeze(0).to(device)
+    with torch.no_grad():
+        image_features = model2.encode_image(image)
+    im_emb_arr = normalized(image_features.cpu().detach().numpy())
+    prediction = model(torch.from_numpy(im_emb_arr).to(device).type(torch.cuda.FloatTensor))
+    return prediction.item()
+
+import json
+import glob
+import shutil
+
+imdir = '../250k_data-0/img/'
+ext = ['png', 'jpg', 'jpeg', 'bmp']
+images = []
+[images.extend(glob.glob(imdir + '*.' + e)) for e in ext]
+
+aesthetic_scores = {}
+
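+# Score every image under imdir and sort it: scores above 6.0 move the file into a
+# sibling 'aesthetic' directory, scores below 5.0 into 'nonaesthetic'; everything
+# scored so far is written to scores.json on exit (including after Ctrl-C).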
+try:
+    for i in tqdm.tqdm(images):
+        try:
+            score = aesthetic(i)
+        except Exception:
+            print(f'skipping {i}')
+            continue
+        if score < 5.0:
+            shutil.move(i, i.replace('img', 'nonaesthetic'))
+        elif score > 6.0:
+            shutil.move(i, i.replace('img', 'aesthetic'))
+        aesthetic_scores[i] = score
+except KeyboardInterrupt:
+    pass
+finally:
+    with open('scores.json', 'w') as f:
+        f.write(json.dumps(aesthetic_scores))
diff --git a/aesthetics/sac+logos+ava1-l14-linearMSE.pth b/aesthetics/sac+logos+ava1-l14-linearMSE.pth
new file mode 100644
index 0000000..7c0d8aa
Binary files /dev/null and b/aesthetics/sac+logos+ava1-l14-linearMSE.pth differ
diff --git a/ldm/data/local.py b/ldm/data/local.py
new file mode 100644
index 0000000..799ecf7
--- /dev/null
+++ b/ldm/data/local.py
@@ -0,0 +1,122 @@
+import os
+import numpy as np
+import PIL
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+import glob
+
+import random
+
+
+class LocalBase(Dataset):
+    def __init__(self,
+                 data_root,
+                 size=None,
+                 interpolation="bicubic",
+                 flip_p=0.5,
+                 shuffle=False
+                 ):
+        super().__init__()
+
+        print('Fetching data.')
+
+        ext = ['png', 'jpg', 'jpeg', 'bmp']
+        self.image_files = []
+        [self.image_files.extend(glob.glob(f'{data_root}/img/' + '*.' + e)) for e in ext]
+
+        print('Constructing image-caption map.')
+
+        self.examples = {}
+        self.hashes = []
+        for i in self.image_files:
+            hash = i[len(f'{data_root}/img/'):].split('.')[0]
+            self.examples[hash] = {
+                'image': i,
+                'text': f'{data_root}/txt/{hash}.txt'
+            }
+            self.hashes.append(hash)
+
+        print(f'image-caption map has {len(self.examples.keys())} examples')
+
+        self.size = size
+        self.interpolation = {"linear": PIL.Image.LINEAR,
+                              "bilinear": PIL.Image.BILINEAR,
+                              "bicubic": PIL.Image.BICUBIC,
+                              "lanczos": PIL.Image.LANCZOS,
+                              }[interpolation]
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+        self.shuffle = shuffle  # skip_sample() resamples randomly when True, sequentially otherwise
+
+    def random_sample(self):
+        return self.__getitem__(random.randint(0, self.__len__() - 1))
+
+    def sequential_sample(self, i):
+        if i >= self.__len__() - 1:
+            return self.__getitem__(0)
+        return self.__getitem__(i + 1)
+
+    def skip_sample(self, i):
+        if self.shuffle:
+            return self.random_sample()
+        return self.sequential_sample(i=i)
+
+    def get_caption(self, i):
+        example = self.examples[self.hashes[i]]
+        caption = open(example['text'], 'r').read()
+        caption = caption.replace('  ', ' ').replace('\n', ' ').lstrip().rstrip()
+        return caption
+
+    def __len__(self):
+        return len(self.image_files)
+
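+    # __getitem__ returns {'caption': str, 'image': float32 HWC array scaled to [-1, 1]}.
+    # Images are center-cropped to a square, resized to size x size when size is set,
+    # and randomly flipped horizontally; unreadable files or captions fall back to skip_sample().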
example_ret["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example_ret + + +if __name__ == "__main__": + dataset = LocalBase('../250k_data-0', size=512) + example = dataset.__getitem__(137) + print(example['caption']) + image = example['image'] + image = ((image + 1) * 127.5).astype(np.uint8) + image = Image.fromarray(image) + image.save('example.png') \ No newline at end of file