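# NOTE: the lines below are residue from the Hugging Face file viewer, not program code.
# Johnny-Z's picture
# Upload 3 files
# b2a709a verified
# raw
# history blame
# 5.36 kB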
import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, BitImageProcessor, SiglipImageProcessor, SiglipVisionModel
from PIL import Image, ImageOps
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn as nn
# All inference runs on the CPU; PyTorch is capped at four CPU threads.
device = torch.device('cpu')
torch.set_num_threads(4)

# DINOv2 branch. The processor is built by hand so that it only rescales
# (0.00392156862745098 == 1/255) and normalises with the mean/std given below;
# resizing, centre-cropping and RGB conversion are all switched off here.
processor_d = BitImageProcessor(
    do_center_crop=False,
    do_convert_rgb=False,
    do_normalize=True,
    do_rescale=True,
    do_resize=False,
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
    resample=3,
    rescale_factor=0.00392156862745098,
)
model_d = AutoModel.from_pretrained(
    'facebook/dinov2-base',
    attn_implementation="sdpa",
).to(device)

# SigLIP branch. Processor and vision tower come from the same checkpoint.
processor_s = SiglipImageProcessor.from_pretrained(
    'google/siglip-so400m-patch14-384'
)
model_s = SiglipVisionModel.from_pretrained(
    'google/siglip-so400m-patch14-384',
    attn_implementation="sdpa",
).to(device)
class ResidualBlock(nn.Module):
    """Bottleneck MLP block with a learned linear shortcut.

    The main path applies four Linear -> LayerNorm -> Mish stages whose widths
    go ``input_size -> input_size // 2 -> input_size // 4 -> input_size // 2
    -> input_size``. The result is added to ``shortcut(x)``, a separate
    ``Linear(input_size, input_size)`` projection of the input.

    Args:
        input_size: Width of the last dimension of the input (and output).
    """

    def __init__(self, input_size):
        super().__init__()
        half = input_size // 2
        quarter = input_size // 4

        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which parameters are initialised.
        self.linear1 = nn.Linear(input_size, half)
        self.LayerNorm1 = nn.LayerNorm(half)
        self.activation1 = nn.Mish()

        self.linear2 = nn.Linear(half, quarter)
        self.LayerNorm2 = nn.LayerNorm(quarter)
        self.activation2 = nn.Mish()

        self.linear3 = nn.Linear(quarter, half)
        self.LayerNorm3 = nn.LayerNorm(half)
        self.activation3 = nn.Mish()

        self.linear4 = nn.Linear(half, input_size)
        self.LayerNorm4 = nn.LayerNorm(input_size)
        self.activation4 = nn.Mish()

        self.shortcut = nn.Linear(input_size, input_size)

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``; the output has the shape of ``x``."""
        residual = self.shortcut(x)

        hidden = self.activation1(self.LayerNorm1(self.linear1(x)))
        hidden = self.activation2(self.LayerNorm2(self.linear2(hidden)))
        hidden = self.activation3(self.LayerNorm3(self.linear3(hidden)))
        hidden = self.activation4(self.LayerNorm4(self.linear4(hidden)))

        # In-place accumulation onto the freshly computed activation tensor.
        hidden += residual
        return hidden
class MLP(nn.Module):
    """Scalar regression head: ResidualBlock -> Mish -> Linear(input_size, 1).

    Args:
        input_size: Width of the last dimension of the input features.
        xcol: Stored on the instance as ``self.xcol``; not used by this class.
        ycol: Stored on the instance as ``self.ycol``; not used by this class.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        self.layers = nn.Sequential(
            ResidualBlock(self.input_size),
            nn.Mish(),
            # ResidualBlock preserves the feature width, so the output layer
            # must be sized from input_size. It used to be hard-coded to 1920,
            # which made every other input_size fail with a shape mismatch.
            # For input_size == 1920 the layer (and its state_dict entry) is
            # unchanged.
            nn.Linear(self.input_size, 1),
        )

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_size)`` to a tensor of shape ``(..., 1)``."""
        return self.layers(x)
# Build the scoring head and restore its trained weights from the local
# checkpoint. 1920 is presumably the width of the concatenated SigLIP + DINOv2
# embedding assembled in process_image -- TODO confirm against the checkpoint.
mlp = MLP(1920)
s = torch.load(
    "./aesthetic_predictor_huber_ad_ep7.pth",
    map_location=torch.device('cpu'),
)
mlp.load_state_dict(s)

# Move to the inference device and switch to evaluation mode.
mlp.to(device)
mlp.eval()
def normalized(a, axis=-1, order=2):
    """Scale ``a`` to unit norm along ``axis``.

    Args:
        a: Array to normalise.
        axis: Axis along which the norm is taken.
        order: Norm order passed to ``np.linalg.norm`` (2 = Euclidean).

    Returns:
        ``a`` divided by its norm along ``axis``. Slices whose norm is zero
        are divided by 1 instead, so they stay all-zero rather than NaN.
    """
    norms = np.atleast_1d(np.linalg.norm(a, order, axis))
    # Swap zero norms for 1 to avoid a division by zero.
    safe_norms = np.where(norms == 0, 1, norms)
    return a / np.expand_dims(safe_norms, axis)
def _scaled_size(size, max_side):
    """Return ``(width, height)`` scaled so the longer side becomes ``max_side``.

    Each dimension is clamped to at least 1 pixel: plain ``int()`` truncation
    yields 0 for very elongated images, and resizing to a zero-sized image
    fails.
    """
    ratio = max_side / max(size)
    return tuple(max(1, int(side * ratio)) for side in size)


def process_image(image, device):
    """Embed one image with SigLIP and DINOv2 and score it with the MLP head.

    Args:
        image: PIL image; any mode is accepted, transparency is flattened
            onto a white background.
        device: Torch device the model inputs and the MLP input are moved to.

    Returns:
        A tuple ``(im_emb_arr, prediction_value)``:
        ``im_emb_arr`` is a NumPy array holding the L2-normalised SigLIP
        pooled output followed by the L2-normalised DINOv2 pooled output,
        concatenated along axis 1; ``prediction_value`` is the MLP output as
        a Python float.
    """
    # Composite onto opaque white so transparent pixels do not turn black.
    rgba = image.convert('RGBA')
    background = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
    image = Image.alpha_composite(background, rgba).convert('RGB')

    # DINOv2 input: longest side 518, aspect ratio kept, no padding.
    image_d = image.resize(_scaled_size(image.size, 518), Image.LANCZOS)

    # SigLIP input: longest side 384, then padded with white to 384x384.
    image_resized = image.resize(_scaled_size(image.size, 384), Image.LANCZOS)
    image_s = ImageOps.pad(image_resized, (384, 384), color=(255, 255, 255))

    inputs_d = processor_d(image_d, return_tensors="pt").to(device)
    inputs_s = processor_s(image_s, return_tensors="pt").to(device)

    # Inference only: keep the backbones *and* the MLP head out of autograd.
    with torch.no_grad():
        outputs_d = model_d(**inputs_d)
        outputs_s = model_s(**inputs_s)

        class_token_d = normalized(outputs_d.pooler_output.cpu().detach().numpy())
        class_token_s = normalized(outputs_s.pooler_output.cpu().detach().numpy())
        im_emb_arr = np.concatenate((class_token_s, class_token_d), axis=1)

        # `.type(torch.FloatTensor)` always yields a CPU tensor, which would
        # undo the preceding `.to(device)` on any non-CPU device. Cast and
        # move in a single call instead.
        features = torch.from_numpy(im_emb_arr).to(device=device, dtype=torch.float32)
        prediction_value = mlp(features).item()

    return im_emb_arr, prediction_value
def infer(image1, image2):
    """Score two images and measure how similar their embeddings are.

    Args:
        image1: First image, passed through to ``process_image``.
        image2: Second image, passed through to ``process_image``.

    Returns:
        ``(cosine_similarity, score_1, score_2)`` on success. If anything
        raises, the error is printed and ``("Error", "Error", "Error")`` is
        returned so the UI shows a placeholder instead of failing.
    """
    try:
        emb_a, score_a = process_image(image1, device)
        emb_b, score_b = process_image(image2, device)
        # cosine_similarity returns a matrix; with one row per input the
        # single entry sits at [0][0].
        similarity = cosine_similarity(emb_a, emb_b)[0][0]
    except Exception as e:
        print(f"Error during inference: {e}")
        return "Error", "Error", "Error"
    return similarity, score_a, score_b
# Gradio front end: two image inputs, three read-outs, one submit button.
with gr.Blocks() as iface:
    gr.Markdown("# Anime Aesthetic Predictor Based on Twitter User Preferences\nUpload two images to calculate the aesthetic score (0-10).")

    # Input images, handed to `infer` as PIL objects.
    with gr.Row():
        image1 = gr.Image(type="pil")
        image2 = gr.Image(type="pil")

    # One score per image.
    with gr.Row():
        prediction1 = gr.Textbox(label="Aesthetic Score 1")
        prediction2 = gr.Textbox(label="Aesthetic Score 2")

    # Cosine similarity between the two embeddings.
    with gr.Row():
        feature_similarity = gr.Textbox(label="Feature Similarity")

    with gr.Row():
        submit_btn = gr.Button("Submit")

    # The output order must match infer's return order:
    # (similarity, score 1, score 2).
    submit_btn.click(
        infer,
        inputs=[image1, image2],
        outputs=[feature_similarity, prediction1, prediction2],
    )

# Allow at most 10 queued requests, then start the server.
iface.queue(max_size=10)
iface.launch()