rxavier commited on
Commit
9cdb8b2
1 Parent(s): a377f0d

Update off_topic.py

Browse files
Files changed (1) hide show
  1. off_topic.py +136 -0
off_topic.py CHANGED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import time
3
+ import asyncio
4
+ from io import BytesIO
5
+ from typing import List, Optional
6
+
7
+ import httpx
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import torch
11
+ import PIL
12
+ from transformers import CLIPModel, CLIPProcessor
13
+ from PIL import Image
14
+
15
+
16
+ class OffTopicDetector:
17
+ def __init__(self, model_id: str, device: Optional[str] = None):
18
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
19
+ self.processor = CLIPProcessor.from_pretrained(model_id)
20
+ self.model = CLIPModel.from_pretrained(model_id).to(self.device)
21
+
22
+ def predict_probas(self, images: List[PIL.Image.Image], domain: str,
23
+ valid_templates: Optional[List[str]] = None,
24
+ invalid_classes: Optional[List[str]] = None,
25
+ autocast: bool = True):
26
+ if valid_templates:
27
+ valid_classes = [template.format(domain) for template in valid_templates]
28
+ else:
29
+ valid_classes = [f"a photo of {domain}", f"brochure with {domain} image", f"instructions for {domain}", f"{domain} diagram"]
30
+ if not invalid_classes:
31
+ invalid_classes = ["promotional ad with store information", "promotional text", "google maps screenshot", "business card", "qr code"]
32
+ n_valid = len(valid_classes)
33
+ classes = valid_classes + invalid_classes
34
+ print(f"Valid classes: {valid_classes}", f"Invalid classes: {invalid_classes}", sep="\n")
35
+ n_classes = len(classes)
36
+
37
+ start = time.time()
38
+ inputs = self.processor(text=classes, images=images, return_tensors="pt", padding=True).to(self.device)
39
+ if self.device == "cpu" and autocast is True:
40
+ print("Disabling autocast due to device='cpu'.")
41
+ autocast = False
42
+ with torch.autocast(self.device, enabled=autocast):
43
+ with torch.no_grad():
44
+ outputs = self.model(**inputs)
45
+ probas = outputs.logits_per_image.softmax(dim=1).cpu().numpy() # we can take the softmax to get the label probabilities
46
+ end = time.time()
47
+ duration = end - start
48
+ print(f"Device: {self.device}",
49
+ f"Response time: {duration}s",
50
+ f"Response time per image: {round(duration/len(images), 2) * 1000}ms",
51
+ sep="\n")
52
+ valid_probas = probas[:, 0:n_valid].sum(axis=1, keepdims=True)
53
+ invalid_probas = probas[:, n_valid:n_classes].sum(axis=1, keepdims=True)
54
+ return probas, valid_probas, invalid_probas
55
+
56
+ def show(self, images: List[PIL.Image.Image], valid_probas: np.ndarray, n_cols: int = 3, title: Optional[str] = None, threshold: Optional[float] = None):
57
+ if threshold is not None:
58
+ prediction = self.apply_threshold(valid_probas, threshold)
59
+ title_scores = [f"Valid: {pred.squeeze()}" for pred in prediction]
60
+ else:
61
+ prediction = np.round(valid_probas[:, 0], 2)
62
+ title_scores = [f"Valid: {pred:.2f}" for pred in prediction]
63
+ n_images = len(images)
64
+ n_rows = int(np.ceil(n_images / n_cols))
65
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 16))
66
+ for i, ax in enumerate(axes.ravel()):
67
+ ax.axis("off")
68
+ try:
69
+ ax.imshow(images[i])
70
+ ax.set_title(title_scores[i])
71
+ except IndexError:
72
+ continue
73
+ if title:
74
+ fig.suptitle(title)
75
+ fig.tight_layout()
76
+ return
77
+
78
+ def predict_item_probas(self, url_or_id: str,
79
+ valid_templates: Optional[List[str]] = None,
80
+ invalid_classes: Optional[List[str]] = None):
81
+ images, domain = self.get_item_data(url_or_id)
82
+ probas, valid_probas, invalid_probas = self.predict_probas(images, domain, valid_templates,
83
+ invalid_classes)
84
+ return images, domain, probas, valid_probas, invalid_probas
85
+
86
+ def apply_threshold(self, valid_probas: np.ndarray, threshold: float = 0.4):
87
+ return valid_probas >= threshold
88
+
89
+ def get_item_data(self, url_or_id: str):
90
+ if url_or_id.startswith("http"):
91
+ item_id = "".join(url_or_id.split("/")[3].split("-")[:2])
92
+ else:
93
+ item_id = re.sub("-", "", url_or_id)
94
+ response = httpx.get(f"https://api.mercadolibre.com/items/{item_id}").json()
95
+ domain = re.sub("_", " ", response["domain_id"].split("-")[-1]).lower()
96
+ img_urls = [x["url"] for x in response["pictures"]]
97
+ images = self.get_images(img_urls)
98
+ return images, domain
99
+
100
+ def get_images(self, urls: List[str]):
101
+ start = time.time()
102
+ images = asyncio.run(self._gather_download_tasks(urls))
103
+ end = time.time()
104
+ duration = end - start
105
+ print(f"Download time: {duration}s",
106
+ f"Download time per image: {round(duration/len(urls), 2) * 1000}ms",
107
+ sep="\n")
108
+ return asyncio.run(self._gather_download_tasks(urls))
109
+
110
+ async def _gather_download_tasks(self, urls: List[str]):
111
+
112
+ async def _process_download(url: str, client: httpx.AsyncClient):
113
+ response = await client.get(url)
114
+ return Image.open(BytesIO(response.content))
115
+
116
+ async with httpx.AsyncClient() as client:
117
+ tasks = [_process_download(url, client) for url in urls]
118
+ return await asyncio.gather(*tasks)
119
+
120
+ @staticmethod
121
+ def _non_async_get_item_data(url_or_id: str, save_images: bool = False):
122
+ if url_or_id.startswith("http"):
123
+ item_id = "".join(url_or_id.split("/")[3].split("-")[:2])
124
+ else:
125
+ item_id = re.sub("-", "", url_or_id)
126
+ response = httpx.get(f"https://api.mercadolibre.com/items/{item_id}").json()
127
+ domain = re.sub("_", " ", response["domain_id"].split("-")[-1]).lower()
128
+ img_urls = [x["url"] for x in response["pictures"]]
129
+ images = []
130
+ for img_url in img_urls:
131
+ img = httpx.get(img_url)
132
+ images.append(Image.open(BytesIO(img.content)))
133
+ if save_images:
134
+ with open(re.sub("D_NQ_NP_", "", img_url.split("/")[-1]) , "wb") as f:
135
+ f.write(img.content)
136
+ return images, domain