rxavier committed
Commit 1139c3b
1 Parent(s): 81502c7

Update off_topic.py

Files changed (1): off_topic.py (+64, -35)
off_topic.py CHANGED
@@ -13,12 +13,16 @@ import imagehash
 from transformers import CLIPModel, CLIPProcessor
 from PIL import Image
 
+import nest_asyncio
+nest_asyncio.apply()
+
 
 class OffTopicDetector:
-    def __init__(self, model_id: str, device: Optional[str] = None):
+    def __init__(self, model_id: str, device: Optional[str] = None, image_size: str = "E"):
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.processor = CLIPProcessor.from_pretrained(model_id)
         self.model = CLIPModel.from_pretrained(model_id).to(self.device)
+        self.image_size = image_size
 
     def predict_probas(self, images: List[PIL.Image.Image], domain: str,
                        valid_templates: Optional[List[str]] = None,
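
The nest_asyncio.apply() call added above is what lets get_images keep using asyncio.run() for the picture downloads in environments that already have an event loop running, such as a Jupyter notebook; without the patch, the nested call raises a RuntimeError. A minimal sketch of the pattern (the coroutine and URL below are illustrative, not from this file):

import asyncio

import nest_asyncio

nest_asyncio.apply()  # allow asyncio.run() inside an already-running event loop


async def fetch(url: str) -> str:
    await asyncio.sleep(0)  # stand-in for an async HTTP request
    return url

# Without nest_asyncio, calling asyncio.run() from a running loop (e.g. a Jupyter cell)
# fails with "asyncio.run() cannot be called from a running event loop".
print(asyncio.run(fetch("https://example.com/image.jpg")))
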
@@ -35,48 +39,36 @@ class OffTopicDetector:
         print(f"Valid classes: {valid_classes}", f"Invalid classes: {invalid_classes}", sep="\n")
         n_classes = len(classes)
 
+        if self.device == "cuda":
+            torch.cuda.synchronize()
         start = time.time()
         inputs = self.processor(text=classes, images=images, return_tensors="pt", padding=True).to(self.device)
         if self.device == "cpu" and autocast is True:
-            print("Disabling autocast due to device='cpu'.")
             autocast = False
         with torch.autocast(self.device, enabled=autocast):
             with torch.no_grad():
                 outputs = self.model(**inputs)
                 probas = outputs.logits_per_image.softmax(dim=1).cpu().numpy()  # we can take the softmax to get the label probabilities
+        if self.device == "cuda":
+            torch.cuda.synchronize()
         end = time.time()
         duration = end - start
-        print(f"Device: {self.device}",
-              f"Response time: {duration}s",
-              f"Response time per image: {round(duration/len(images), 2) * 1000}ms",
+        print(f"Model time: {round(duration, 2)} s",
+              f"Model time per image: {round(duration/len(images) * 1000, 0)} ms",
               sep="\n")
         valid_probas = probas[:, 0:n_valid].sum(axis=1, keepdims=True)
         invalid_probas = probas[:, n_valid:n_classes].sum(axis=1, keepdims=True)
         return probas, valid_probas, invalid_probas
 
-    def show(self, images: List[PIL.Image.Image], valid_probas: np.ndarray, n_cols: int = 3, title: Optional[str] = None, threshold: Optional[float] = None):
-        if threshold is not None:
-            prediction = self.apply_threshold(valid_probas, threshold)
-            title_scores = [f"Valid: {pred.squeeze()}" for pred in prediction]
-        else:
-            prediction = np.round(valid_probas[:, 0], 2)
-            title_scores = [f"Valid: {pred:.2f}" for pred in prediction]
-        n_images = len(images)
-        n_rows = int(np.ceil(n_images / n_cols))
-        fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 16))
-        for i, ax in enumerate(axes.ravel()):
-            ax.axis("off")
-            try:
-                ax.imshow(images[i])
-                ax.set_title(title_scores[i])
-            except IndexError:
-                continue
-        if title:
-            fig.suptitle(title)
-        fig.tight_layout()
-        return
+    def predict_probas_url(self, img_urls: List[str], domain: str,
+                           valid_templates: Optional[List[str]] = None,
+                           invalid_classes: Optional[List[str]] = None,
+                           autocast: bool = True):
+        images = self.get_images(img_urls)
+        dedup_images = self._filter_dups(images)
+        return self.predict_probas(images, domain, valid_templates, invalid_classes, autocast)
 
-    def predict_item_probas(self, url_or_id: str,
+    def predict_probas_item(self, url_or_id: str,
                             valid_templates: Optional[List[str]] = None,
                             invalid_classes: Optional[List[str]] = None):
         images, domain = self.get_item_data(url_or_id)
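
In the hunk above, the model call is now bracketed by torch.cuda.synchronize() on both sides of the timers, so the printed "Model time" measures completed GPU work rather than just the asynchronous kernel launches. A minimal sketch of that timing pattern, assuming an arbitrary module and input batch (both hypothetical names):

import time
from typing import Tuple

import torch


def timed_forward(model: torch.nn.Module, batch: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Time a forward pass, forcing pending CUDA work to finish before reading the clock."""
    if batch.is_cuda:
        torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        out = model(batch)
    if batch.is_cuda:
        torch.cuda.synchronize()
    return out, time.time() - start

One detail worth noting: as written, the new predict_probas_url builds dedup_images but still passes images to predict_probas, so the deduplicated list is presumably what was meant to be forwarded.
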
@@ -92,24 +84,38 @@ class OffTopicDetector:
             item_id = "".join(url_or_id.split("/")[3].split("-")[:2])
         else:
             item_id = re.sub("-", "", url_or_id)
+        start = time.time()
         response = httpx.get(f"https://api.mercadolibre.com/items/{item_id}").json()
         domain = re.sub("_", " ", response["domain_id"].split("-")[-1]).lower()
         img_urls = [x["url"] for x in response["pictures"]]
+        img_urls = [x.replace("-O.jpg", f"-{self.image_size}.jpg") for x in img_urls]
+        end = time.time()
+        duration = end - start
+        print(f"Items API time: {round(duration * 1000, 0)} ms")
         images = self.get_images(img_urls)
-        hashes = {}
-        for img in images:
-            hashes.update({str(imagehash.average_hash(img)): img})
-        dedup_hashes = list(dict.fromkeys(hashes))
-        dedup_images = [img for hash, img in hashes.items() if hash in dedup_hashes]
+        dedup_images = self._filter_dups(images)
         return dedup_images, domain
 
+    def _filter_dups(self, images: List):
+        if len(images) > 1:
+            hashes = {}
+            for img in images:
+                hashes.update({str(imagehash.average_hash(img)): img})
+            dedup_hashes = list(dict.fromkeys(hashes))
+            dedup_images = [img for hash, img in hashes.items() if hash in dedup_hashes]
+        else:
+            dedup_images = images
+        if (diff := len(images) - len(dedup_images)) > 0:
+            print(f"Filtered {diff} images out of {len(images)} due to matching hashes.")
+        return dedup_images
+
     def get_images(self, urls: List[str]):
         start = time.time()
         images = asyncio.run(self._gather_download_tasks(urls))
         end = time.time()
         duration = end - start
-        print(f"Download time: {duration}s",
-              f"Download time per image: {round(duration/len(urls), 2) * 1000}ms",
+        print(f"Download time: {round(duration, 2)} s",
+              f"Download time per image: {round(duration/len(urls) * 1000, 0)} ms",
              sep="\n")
        return asyncio.run(self._gather_download_tasks(urls))
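
The new _filter_dups helper above removes near-duplicate listing photos by comparing imagehash.average_hash values and keeping a single image per hash. The same idea in a standalone sketch (the function name is ours, not from the file); as in the commit, a later duplicate simply replaces an earlier one with the same hash:

from typing import List

import imagehash
from PIL import Image


def filter_duplicates(images: List[Image.Image]) -> List[Image.Image]:
    """Keep one image per average-hash value."""
    by_hash = {}
    for img in images:
        by_hash[str(imagehash.average_hash(img))] = img  # same hash -> overwritten, i.e. deduplicated
    return list(by_hash.values())
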
@@ -139,4 +145,27 @@ class OffTopicDetector:
         if save_images:
             with open(re.sub("D_NQ_NP_", "", img_url.split("/")[-1]) , "wb") as f:
                 f.write(img.content)
-        return images, domain
+        return images, domain
+
+    def show(self, images: List[PIL.Image.Image], valid_probas: np.ndarray, n_cols: int = 3,
+             title: Optional[str] = None, threshold: Optional[float] = None):
+        if threshold is not None:
+            prediction = self.apply_threshold(valid_probas, threshold)
+            title_scores = [f"Valid: {pred.squeeze()}" for pred in prediction]
+        else:
+            prediction = np.round(valid_probas[:, 0], 2)
+            title_scores = [f"Valid: {pred:.2f}" for pred in prediction]
+        n_images = len(images)
+        n_rows = int(np.ceil(n_images / n_cols))
+        fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 16))
+        for i, ax in enumerate(axes.ravel()):
+            ax.axis("off")
+            try:
+                ax.imshow(images[i])
+                ax.set_title(title_scores[i])
+            except IndexError:
+                continue
+        if title:
+            fig.suptitle(title)
+        fig.tight_layout()
+        return
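
The show method added above titles each tile either with the rounded valid probability or, when a threshold is passed, with whatever apply_threshold returns. That helper is not part of this diff; assuming it is a plain cutoff on the valid-probability column, it would look roughly like this hypothetical sketch:

import numpy as np


def apply_threshold(valid_probas: np.ndarray, threshold: float) -> np.ndarray:
    # Hypothetical: the real apply_threshold lives elsewhere in the repository.
    return valid_probas >= threshold
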
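
Putting the renamed entry points together, a typical call sequence might look like the sketch below. The model id and item id are placeholders, and it assumes predict_probas_item forwards to predict_probas and returns the same (probas, valid_probas, invalid_probas) triple:

from off_topic import OffTopicDetector

# Placeholder model id and item id, not taken from this commit.
detector = OffTopicDetector("openai/clip-vit-base-patch32", image_size="E")

# Fetch a listing's pictures (resized via image_size, deduplicated) and classify them.
probas, valid_probas, invalid_probas = detector.predict_probas_item("MLA123456789")

# Or classify known picture URLs directly against an explicit domain string.
# probas, valid_probas, invalid_probas = detector.predict_probas_url(img_urls, domain="sneakers")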