add local dataset code
This commit is contained in:
parent
b1a57dce78
commit
e6c1e048ea
|
@ -6,6 +6,8 @@
|
|||
|
||||
# Programming - general
|
||||
*.log
|
||||
example.png
|
||||
scores.json
|
||||
|
||||
|
||||
# =========================================================================== #
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
import webdataset as wds
|
||||
from PIL import Image
|
||||
import io
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
import json
|
||||
|
||||
from warnings import filterwarnings
|
||||
|
||||
|
||||
# Select which GPU this process may use; note this assignment sits ahead of the torch import below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # choose GPU if you are on a multi GPU server
|
||||
import numpy as np
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn as nn
|
||||
from torchvision import datasets, transforms
|
||||
import tqdm
|
||||
|
||||
from os.path import join
|
||||
from datasets import load_dataset
|
||||
import pandas as pd
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import json
|
||||
|
||||
import clip
|
||||
|
||||
|
||||
from PIL import Image, ImageFile
|
||||
|
||||
|
||||
##### Example image file (note: the batch loop further down scores every image found in `imdir`):
|
||||
|
||||
# NOTE(review): no visible code reads this module-level name; `aesthetic` takes its own `img_path` argument.
img_path = "../250k_data-0/img/000baa665498e7a61130d7662f81e698.jpg"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# if you changed the MLP architecture during training, change it also here:
|
||||
class MLP(pl.LightningModule):
    """Regression head that maps an embedding vector to a single scalar.

    The network is a stack of ``nn.Linear`` layers with dropout in between and
    no activation functions; training and validation use mean-squared error.

    Args:
        input_size: Width of the input embedding vector.
        xcol: Batch key under which the input embeddings are stored.
        ycol: Batch key under which the regression targets are stored.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        # Keep the order of modules exactly as it was: the state-dict keys
        # ("layers.0.weight", "layers.2.weight", ...) are derived from the
        # position of each module inside the Sequential.
        # The ReLU activations were disabled in the original, so none appear.
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        return self.layers(x)

    def _mse_step(self, batch):
        """Compute the MSE between the prediction and the target column."""
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        prediction = self.layers(x)
        # `F` was never imported in this file; go through the `nn` alias that
        # is already in scope instead.
        return nn.functional.mse_loss(prediction, y)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        return self._mse_step(batch)

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch."""
        return self._mse_step(batch)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
def normalized(a, axis=-1, order=2):
    """Scale ``a`` so that its norm along ``axis`` becomes 1.

    Slices whose norm is exactly zero are divided by 1 instead, so they are
    returned unchanged and no division by zero takes place.

    Args:
        a: Array to normalise.
        axis: Axis along which the norm is computed.
        order: Norm order handed to ``np.linalg.norm``.

    Returns:
        ``a`` divided by its norm, with the norm broadcast back along ``axis``.
    """
    import numpy as np  # pylint: disable=import-outside-toplevel

    norms = np.linalg.norm(a, ord=order, axis=axis)
    norms = np.atleast_1d(norms)
    # Replace zero norms by 1 in place so the division below is a no-op there.
    is_zero = norms == 0
    norms[is_zero] = 1
    divisor = np.expand_dims(norms, axis)
    return a / divisor
|
||||
|
||||
|
||||
# Decide on the device once, before anything is loaded, so that the MLP head,
# the CLIP encoder and the input tensors all end up in the same place.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = MLP(768)  # CLIP embedding dim is 768 for CLIP ViT L 14

# load the model you trained previously or the model available in this repo;
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine
s = torch.load("sac+logos+ava1-l14-linearMSE.pth", map_location=device)

model.load_state_dict(s)

# Follow `device` rather than a hard-coded "cuda" so the CPU fallback works.
model.to(device)
model.eval()

model2, preprocess = clip.load("ViT-L/14", device=device)  # RN50x64
|
||||
|
||||
@torch.inference_mode()
def aesthetic(img_path):
    """Return the predicted score for a single image file.

    The image is preprocessed and encoded with the CLIP model (``model2``),
    the embedding is normalised with ``normalized`` (default order 2) and then
    passed through the MLP head (``model``).

    Args:
        img_path: Path of the image file to score.

    Returns:
        The single predicted value as a Python float.
    """
    # Open through a context manager so the file handle is released as soon
    # as preprocessing has turned the image into a tensor.
    with Image.open(img_path) as pil_image:
        image = preprocess(pil_image).unsqueeze(0).to(device)

    # The inference_mode decorator already disables autograd, so no extra
    # torch.no_grad() block is needed here.
    image_features = model2.encode_image(image)
    im_emb_arr = normalized(image_features.cpu().detach().numpy())

    # Cast with an explicit dtype on `device` instead of forcing
    # torch.cuda.FloatTensor, which fails when running on CPU.
    embedding = torch.from_numpy(im_emb_arr).to(device=device, dtype=torch.float32)
    prediction = model(embedding)
    return prediction.item()
|
||||
|
||||
import json
|
||||
import glob
|
||||
import shutil
|
||||
|
||||
# Score every image under `imdir`, move low/high scoring files into sibling
# folders and dump all scores to scores.json.
imdir = '../250k_data-0/img/'
ext = ['png', 'jpg', 'jpeg', 'bmp']
images = []
for e in ext:
    images.extend(glob.glob(imdir + '*.' + e))

# Destination folders live next to the image folder. Build them from the
# parent directory rather than str.replace('img', ...), which would also
# rewrite any other "img" substring in the path or file name.
_data_root = os.path.dirname(os.path.normpath(imdir))
_nonaesthetic_dir = os.path.join(_data_root, 'nonaesthetic')
_aesthetic_dir = os.path.join(_data_root, 'aesthetic')

aesthetic_scores = {}

try:
    for i in tqdm.tqdm(images):
        try:
            score = aesthetic(i)
        except Exception as err:
            # `except Exception` (not a bare except) lets KeyboardInterrupt
            # reach the outer handler, so Ctrl-C really stops the run.
            print(f'skipping {i}: {err!r}')
            continue
        if score < 5.0:
            os.makedirs(_nonaesthetic_dir, exist_ok=True)
            shutil.move(i, os.path.join(_nonaesthetic_dir, os.path.basename(i)))
        elif score > 6.0:
            os.makedirs(_aesthetic_dir, exist_ok=True)
            shutil.move(i, os.path.join(_aesthetic_dir, os.path.basename(i)))
        # Scores are keyed by the original path, whether or not the file moved.
        aesthetic_scores[i] = score
except KeyboardInterrupt:
    # Interrupting is allowed; whatever was scored so far is still saved below.
    pass
finally:
    with open('scores.json', 'w') as f:
        json.dump(aesthetic_scores, f)
|
Binary file not shown.
|
@ -0,0 +1,117 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
import glob
|
||||
|
||||
import random
|
||||
|
||||
|
||||
class LocalBase(Dataset):
    """Image/caption dataset read from a local directory.

    Images are globbed from ``{data_root}/img/`` (png, jpg, jpeg, bmp). Each
    image is paired with the caption file ``{data_root}/txt/{name}.txt``,
    where ``name`` is the image file name up to its first dot.

    Every item is a dict with two keys:
        ``caption``: caption text with newlines replaced by spaces and
            surrounding whitespace stripped.
        ``image``: the centre-cropped square image, optionally resized to
            ``size`` x ``size`` and randomly flipped, as a float32 array
            scaled to the range [-1, 1].

    Args:
        data_root: Directory containing the ``img`` and ``txt`` folders.
        size: Edge length to resize to; ``None`` keeps the cropped size.
        interpolation: One of "linear", "bilinear", "bicubic", "lanczos".
        flip_p: Probability of the random horizontal flip.
        shuffle: When an item cannot be loaded, substitute a random item
            (True) or the next one in order (False).
    """

    def __init__(self,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5,
                 shuffle=False,
                 ):
        super().__init__()

        print('Fetching data.')

        ext = ['png', 'jpg', 'jpeg', 'bmp']
        self.image_files = []
        for e in ext:
            self.image_files.extend(glob.glob(f'{data_root}/img/' + '*.' + e))

        print('Constructing image-caption map.')

        self.examples = {}
        self.hashes = []
        for path in self.image_files:
            # Key = file name up to the first dot (same rule as before, but
            # without shadowing the builtin `hash`).
            name = os.path.basename(path).split('.')[0]
            self.examples[name] = {
                'image': path,
                'text': f'{data_root}/txt/{name}.txt',
            }
            self.hashes.append(name)

        print(f'image-caption map has {len(self.examples)} examples')

        self.size = size
        # Image.LINEAR was an alias of BILINEAR and no longer exists in
        # Pillow >= 10, so "linear" maps to BILINEAR directly.
        self.interpolation = {"linear": Image.BILINEAR,
                              "bilinear": Image.BILINEAR,
                              "bicubic": Image.BICUBIC,
                              "lanczos": Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        # skip_sample() reads this attribute; it was never set before.
        self.shuffle = shuffle

    def random_sample(self):
        """Return the item at a uniformly random index."""
        return self[random.randint(0, len(self) - 1)]

    def sequential_sample(self, i):
        """Return the item after index ``i``, wrapping around to index 0."""
        if i >= len(self) - 1:
            return self[0]
        return self[i + 1]

    def skip_sample(self, i):
        """Return a replacement item for the unusable item at index ``i``."""
        if self.shuffle:
            return self.random_sample()
        return self.sequential_sample(i=i)

    def get_caption(self, i):
        """Read and return the cleaned caption text for index ``i``.

        Raises:
            OSError: If the caption file cannot be opened.
            ValueError: If the file is not valid UTF-8 (UnicodeDecodeError).
        """
        example = self.examples[self.hashes[i]]
        # Context manager so the file handle is closed right away.
        with open(example['text'], 'r', encoding='utf-8') as f:
            caption = f.read()
        # NOTE(review): the original also called replace(' ', ' '), which is a
        # no-op as written; it may have been meant to collapse double spaces.
        return caption.replace('\n', ' ').strip()

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.image_files)

    @staticmethod
    def _center_crop_square(img):
        """Crop an array to a centred square over its first two axes."""
        h, w = img.shape[0], img.shape[1]
        crop = min(h, w)
        return img[(h - crop) // 2:(h + crop) // 2,
                   (w - crop) // 2:(w + crop) // 2]

    def __getitem__(self, i):
        """Return the caption/image dict for index ``i``.

        If the image or its caption cannot be loaded, a message is printed
        and the item chosen by ``skip_sample`` is returned instead.
        """
        image_file = self.examples[self.hashes[i]]['image']
        try:
            # convert() decodes the pixel data inside the try block, so
            # truncated or corrupt files are caught here instead of failing
            # later, and the `with` closes the file handle.
            with Image.open(image_file) as opened:
                image = opened.convert("RGB")
        except (OSError, ValueError):
            print(f'Error with {image_file} -- skipping {i}')
            return self.skip_sample(i)

        try:
            caption = self.get_caption(i)
            if caption is None:
                raise ValueError
        except (OSError, ValueError):
            print(f'Error with caption of {image_file} -- skipping {i}')
            return self.skip_sample(i)

        # default to score-sde preprocessing
        img = self._center_crop_square(np.array(image).astype(np.uint8))

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        return {
            'caption': caption,
            # Map uint8 [0, 255] to float32 [-1, 1].
            'image': (image / 127.5 - 1.0).astype(np.float32),
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: load one sample, print its caption and write the image out.
    dataset = LocalBase('../250k_data-0', size=512)
    example = dataset[137]
    print(example['caption'])
    image = example['image']
    # Undo the [-1, 1] scaling. Round before casting: a plain uint8 cast
    # truncates, so a float result such as 99.99999 would become 99.
    image = np.rint((image + 1) * 127.5).clip(0, 255).astype(np.uint8)
    image = Image.fromarray(image)
    image.save('example.png')
|
Loading…
Reference in New Issue