Update README.md
README.md CHANGED
@@ -23,4 +23,117 @@ User interface (UI) design is a difficult yet important task for ensuring the us

- **Developed by:** BigLab
- **Model type:** CLIP-style Multi-modal Dual-encoder Transformer
- **Language(s) (NLP):** English
- **License:** MIT
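
The example below scores a batch of UI screenshots against their textual descriptions. Each screenshot is resized, cut into square sliding windows, and embedded with the image encoder; the per-window embeddings are averaged and compared against a "well-designed" text prompt (and, when `NORMALIZE_SCORING` is enabled, softmax-normalized against a matching "poor design" prompt).
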
```python
import torch
from transformers import CLIPProcessor, CLIPModel

IMG_SIZE = 224
DEVICE = "cpu"  # can also be "cuda" or "mps"
LOGIT_SCALE = 100  # based on OpenAI's CLIP example code
NORMALIZE_SCORING = True

model_path = "uiclip_jitteredwebsites-2-224-paraphrased"  # can also be the webpairs or humanpairs variants
processor_path = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(model_path)
model = model.eval()
model = model.to(DEVICE)

processor = CLIPProcessor.from_pretrained(processor_path)

def compute_quality_scores(input_list):
    # input_list is a list of tuples where the first element is a description and the second is a PIL image
    description_list = ["ui screenshot. well-designed. " + input_item[0] for input_item in input_list]
    img_list = [input_item[1] for input_item in input_list]
    text_embeddings_tensor = compute_description_embeddings(description_list)  # B x H
    img_embeddings_tensor = compute_image_embeddings(img_list)  # B x H

    # normalize tensors
    text_embeddings_tensor /= text_embeddings_tensor.norm(dim=-1, keepdim=True)
    img_embeddings_tensor /= img_embeddings_tensor.norm(dim=-1, keepdim=True)

    if NORMALIZE_SCORING:
        # also embed a "poor design" prompt so the two scores can be softmax-normalized
        text_embeddings_tensor_poor = compute_description_embeddings([d.replace("well-designed. ", "poor design. ") for d in description_list])  # B x H
        text_embeddings_tensor_poor /= text_embeddings_tensor_poor.norm(dim=-1, keepdim=True)
        text_embeddings_tensor_all = torch.stack((text_embeddings_tensor, text_embeddings_tensor_poor), dim=1)  # B x 2 x H
    else:
        text_embeddings_tensor_all = text_embeddings_tensor.unsqueeze(1)  # B x 1 x H

    img_embeddings_tensor = img_embeddings_tensor.unsqueeze(1)  # B x 1 x H

    scores = (LOGIT_SCALE * img_embeddings_tensor @ text_embeddings_tensor_all.permute(0, 2, 1)).squeeze(1)

    if NORMALIZE_SCORING:
        scores = scores.softmax(dim=-1)

    return scores[:, 0]

def compute_description_embeddings(descriptions):
    inputs = processor(text=descriptions, return_tensors="pt", padding=True)
    inputs['input_ids'] = inputs['input_ids'].to(DEVICE)
    inputs['attention_mask'] = inputs['attention_mask'].to(DEVICE)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs)
    return text_embedding

def compute_image_embeddings(image_list):
    windowed_batch = [slide_window_over_image(img, IMG_SIZE) for img in image_list]
    # remember which source image each window came from
    inds = []
    for imgi in range(len(windowed_batch)):
        inds.append([imgi for _ in windowed_batch[imgi]])

    processed_batch = [item for sublist in windowed_batch for item in sublist]
    inputs = processor(images=processed_batch, return_tensors="pt")
    # run all sub windows of all images in batch through the model
    inputs['pixel_values'] = inputs['pixel_values'].to(DEVICE)
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)

    # output contains all subwindows, need to mask for each image
    processed_batch_inds = torch.tensor([item for sublist in inds for item in sublist]).long().to(image_features.device)
    embed_list = []
    for i in range(len(image_list)):
        mask = processed_batch_inds == i
        embed_list.append(image_features[mask].mean(dim=0))
    image_embedding = torch.stack(embed_list, dim=0)
    return image_embedding

def preresize_image(image, image_size):
    # scale so the shorter side equals image_size, preserving aspect ratio
    aspect_ratio = image.width / image.height
    if aspect_ratio > 1:
        image = image.resize((int(aspect_ratio * image_size), image_size))
    else:
        image = image.resize((image_size, int(image_size / aspect_ratio)))
    return image

def slide_window_over_image(input_image, img_size):
    # tile the longer dimension of the resized screenshot with overlapping square crops
    input_image = preresize_image(input_image, img_size)
    width, height = input_image.size
    square_size = min(width, height)
    longer_dimension = max(width, height)
    num_steps = (longer_dimension + square_size - 1) // square_size

    if num_steps > 1:
        step_size = (longer_dimension - square_size) // (num_steps - 1)
    else:
        step_size = square_size

    cropped_images = []

    for y in range(0, height - square_size + 1, step_size if height > width else square_size):
        for x in range(0, width - square_size + 1, step_size if width > height else square_size):
            left = x
            upper = y
            right = x + square_size
            lower = y + square_size
            cropped_image = input_image.crop((left, upper, right, lower))
            cropped_images.append(cropped_image)

    return cropped_images


# compute the quality scores for a list of descriptions (strings) and images (PIL images);
# test_descriptions and test_images must be supplied by the caller (see the sketch below)
prediction_scores = compute_quality_scores(list(zip(test_descriptions, test_images)))
```
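
As a concrete usage sketch (the file names and descriptions here are hypothetical placeholders, not part of the model release), the inputs could be assembled like this:

```python
from PIL import Image

# hypothetical example inputs; substitute your own screenshots and descriptions
test_images = [Image.open("screenshot_1.png"), Image.open("screenshot_2.png")]
test_descriptions = [
    "login screen for a banking app",
    "settings page with a dark mode toggle",
]

scores = compute_quality_scores(list(zip(test_descriptions, test_images)))
print(scores)  # one score per (description, image) pair; higher suggests a better-rated design
```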