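"""Gradio demo: anime aesthetic scoring plus image feature similarity.

Each image is embedded with DINOv2 (facebook/dinov2-base) and SigLIP
(google/siglip-so400m-patch14-384); the L2-normalized pooled embeddings are
concatenated into a 1920-dim vector (1152 + 768) and scored by a small
residual MLP aesthetic predictor. The app also reports the cosine similarity
between the two images' concatenated embeddings.
"""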
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageOps
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, BitImageProcessor, SiglipImageProcessor, SiglipVisionModel

# CPU-only inference with a capped thread count.
device = torch.device('cpu')
torch.set_num_threads(4)

# Images are resized manually in process_image, so the DINOv2 processor only
# rescales to [0, 1] and applies ImageNet normalization.
processor_d = BitImageProcessor(
    do_resize=False,
    do_center_crop=False,
    do_convert_rgb=False,
    do_rescale=True,
    rescale_factor=1 / 255,
    do_normalize=True,
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
    resample=3,  # PIL BICUBIC; unused since do_resize=False
)
model_d = AutoModel.from_pretrained('facebook/dinov2-base', attn_implementation="sdpa").to(device)

processor_s = SiglipImageProcessor.from_pretrained('google/siglip-so400m-patch14-384')
model_s = SiglipVisionModel.from_pretrained('google/siglip-so400m-patch14-384', attn_implementation="sdpa").to(device)

class ResidualBlock(nn.Module):
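    """Bottleneck residual block: width -> width/2 -> width/4 -> width/2 -> width,
    with a learned linear shortcut added back to the output."""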
    def __init__(self, input_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, input_size // 2)
        self.LayerNorm1 = nn.LayerNorm(input_size // 2)
        self.activation1 = nn.Mish()

        self.linear2 = nn.Linear(input_size // 2, input_size // 4)
        self.LayerNorm2 = nn.LayerNorm(input_size // 4)
        self.activation2 = nn.Mish()

        self.linear3 = nn.Linear(input_size // 4, input_size // 2)
        self.LayerNorm3 = nn.LayerNorm(input_size // 2)
        self.activation3 = nn.Mish()

        self.linear4 = nn.Linear(input_size // 2, input_size)
        self.LayerNorm4 = nn.LayerNorm(input_size)
        self.activation4 = nn.Mish()

        self.shortcut = nn.Linear(input_size, input_size)

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.linear1(x)
        out = self.LayerNorm1(out)
        out = self.activation1(out)

        out = self.linear2(out)
        out = self.LayerNorm2(out)
        out = self.activation2(out)

        out = self.linear3(out)
        out = self.LayerNorm3(out)
        out = self.activation3(out)

        out = self.linear4(out)
        out = self.LayerNorm4(out)
        out = self.activation4(out)

        # Add the projected input back (residual connection).
        out += identity
        return out

class MLP(nn.Module):
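    """Aesthetic score head: one residual block followed by a linear projection to a scalar."""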
    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        # Column names retained from the training setup; unused at inference.
        self.xcol = xcol
        self.ycol = ycol
        self.layers = nn.Sequential(
            ResidualBlock(self.input_size),
            nn.Mish(),
            nn.Linear(self.input_size, 1),  # scalar aesthetic score
        )

    def forward(self, x):
        return self.layers(x)

# 1920 = SigLIP so400m pooled width (1152) + DINOv2-base pooled width (768).
mlp = MLP(1920)

state_dict = torch.load("./aesthetic_predictor_huber_ad_ep7.pth", map_location=device)
mlp.load_state_dict(state_dict)
mlp.to(device)
mlp.eval()

def normalized(a, axis=-1, order=2):
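    """L2-normalize `a` along `axis`; rows with zero norm are left unscaled."""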
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)

def process_image(image, device):
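    """Embed one PIL image and score it; returns (1x1920 embedding, aesthetic score)."""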
    # Flatten any alpha channel onto a white background.
    image = image.convert('RGBA')
    background = Image.new('RGBA', image.size, (255, 255, 255, 255))
    image = Image.alpha_composite(background, image).convert('RGB')

    # DINOv2 input: longest side 518 px (37 patches of 14 px), aspect ratio preserved.
    max_side = 518
    ratio = max_side / max(image.size)
    new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
    image_d = image.resize(new_size, Image.LANCZOS)

    # SigLIP input: longest side 384 px, padded to a 384x384 square with white.
    max_side_s = 384
    ratio_s = max_side_s / max(image.size)
    new_size_s = (int(image.size[0] * ratio_s), int(image.size[1] * ratio_s))
    image_resized = image.resize(new_size_s, Image.LANCZOS)
    image_s = ImageOps.pad(image_resized, (384, 384), color=(255, 255, 255))

    inputs_d = processor_d(image_d, return_tensors="pt").to(device)
    inputs_s = processor_s(image_s, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs_d = model_d(**inputs_d)
        outputs_s = model_s(**inputs_s)

    # L2-normalize each pooled embedding, then concatenate: [SigLIP 1152 | DINOv2 768].
    class_token_d = normalized(outputs_d.pooler_output.detach().cpu().numpy())
    class_token_s = normalized(outputs_s.pooler_output.detach().cpu().numpy())
    im_emb_arr = np.concatenate((class_token_s, class_token_d), axis=1)

    # Cast to float32 before moving to the model's device.
    prediction_value = mlp(torch.from_numpy(im_emb_arr).float().to(device)).item()

    return im_emb_arr, prediction_value


def infer(image1, image2):
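    """Score both images and compute the cosine similarity of their embeddings."""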
    try:
        features1, prediction_value1 = process_image(image1, device)
        features2, prediction_value2 = process_image(image2, device)

        # Cosine similarity between the two 1x1920 concatenated embeddings.
        cos_sim_features = cosine_similarity(features1, features2)[0][0]

        return cos_sim_features, prediction_value1, prediction_value2
    except Exception as e:
        print(f"Error during inference: {e}")
        return "Error", "Error", "Error"


with gr.Blocks() as iface:
    gr.Markdown("# Anime Aesthetic Predictor Based on Twitter User Preferences\nUpload two images to get an aesthetic score (0-10) for each and the cosine similarity between their feature embeddings.")
    with gr.Row():
        image1 = gr.Image(type="pil")
        image2 = gr.Image(type="pil")
    with gr.Row():
        prediction1 = gr.Textbox(label="Aesthetic Score 1")
        prediction2 = gr.Textbox(label="Aesthetic Score 2")
    with gr.Row():
        feature_similarity = gr.Textbox(label="Feature Similarity")
    with gr.Row():
        submit_btn = gr.Button("Submit")

    submit_btn.click(infer, inputs=[image1, image2], outputs=[feature_similarity, prediction1, prediction2])
    
iface.queue(max_size=10)
iface.launch()