theArijitDas commited on
Commit
ecc6de5
1 Parent(s): d29ec33

Update image_validator.py

Browse files
Files changed (1) hide show
  1. image_validator.py +62 -63
image_validator.py CHANGED
@@ -1,64 +1,63 @@
1
- from transformers import CLIPProcessor, CLIPModel, ViTImageProcessor, ViTModel
2
- from PIL import Image
3
- from sklearn.metrics.pairwise import cosine_similarity
4
-
5
- from warnings import filterwarnings
6
- filterwarnings("ignore")
7
-
8
- models = ["CLIP-ViT Base", "ViT Base", "DINO ViT-S16"]
9
- models_info = {
10
- "CLIP-ViT Base": {
11
- "model_size": "386MB",
12
- "model_url": "openai/clip-vit-base-patch32",
13
- "efficiency": "High",
14
- },
15
- "ViT Base": {
16
- "model_size": "304MB",
17
- "model_url": "google/vit-base-patch16-224",
18
- "efficiency": "High",
19
- },
20
- "DINO ViT-S16": {
21
- "model_size": "1.34GB",
22
- "model_url": "facebook/dino-vits16",
23
- "efficiency": "Moderate",
24
- },
25
- }
26
-
27
- class Image_Validator:
28
- def __init__(self, model_name=None):
29
- if model_name is None: model_name="ViT Base"
30
-
31
- self.model_info = models_info[model_name]
32
- model_url = self.model_info["model_url"]
33
-
34
- if model_name == "CLIP-ViT Base":
35
- self.model = CLIPModel.from_pretrained(model_url)
36
- self.processor = CLIPProcessor.from_pretrained(model_url)
37
-
38
- elif model_name == "ViT Base":
39
- self.model = ViTModel.from_pretrained(model_url)
40
- self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)
41
-
42
- elif model_name == "DINO ViT-S16":
43
- self.model = ViTModel.from_pretrained(model_url)
44
- self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)
45
-
46
- def get_image_embedding(self, image_path):
47
- image = Image.open(image_path)
48
-
49
- # Process image according to the model
50
- if hasattr(self, 'processor'): # CLIP models
51
- inputs = self.processor(images=image, return_tensors="pt")
52
- outputs = self.model.get_image_features(**inputs)
53
-
54
- elif hasattr(self, 'feature_extractor'): # ViT models
55
- inputs = self.feature_extractor(images=image, return_tensors="pt")
56
- outputs = self.model(**inputs).last_hidden_state
57
-
58
- return outputs
59
-
60
- def similarity_score(self, image_path_1, image_path_2):
61
- embedding1 = self.get_image_embedding(image_path_1).reshape(1, -1)
62
- embedding2 = self.get_image_embedding(image_path_2).reshape(1, -1)
63
- similarity = cosine_similarity(embedding1.detach().numpy(), embedding2.detach().numpy())
64
  return similarity[0][0]
 
1
+ from transformers import CLIPProcessor, CLIPModel, ViTImageProcessor, ViTModel
2
+ from PIL import Image
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+
5
+ from warnings import filterwarnings
6
+ filterwarnings("ignore")
7
+
8
+ models = ["CLIP-ViT Base", "ViT Base", "DINO ViT-S16"]
9
+ models_info = {
10
+ "CLIP-ViT Base": {
11
+ "model_size": "386MB",
12
+ "model_url": "openai/clip-vit-base-patch32",
13
+ "efficiency": "High",
14
+ },
15
+ "ViT Base": {
16
+ "model_size": "304MB",
17
+ "model_url": "google/vit-base-patch16-224",
18
+ "efficiency": "High",
19
+ },
20
+ "DINO ViT-S16": {
21
+ "model_size": "1.34GB",
22
+ "model_url": "facebook/dino-vits16",
23
+ "efficiency": "Moderate",
24
+ },
25
+ }
26
+
27
+ class Image_Validator:
28
+ def __init__(self, model_name=None):
29
+ if model_name is None: model_name="ViT Base"
30
+
31
+ self.model_info = models_info[model_name]
32
+ model_url = self.model_info["model_url"]
33
+
34
+ if model_name == "CLIP-ViT Base":
35
+ self.model = CLIPModel.from_pretrained(model_url)
36
+ self.processor = CLIPProcessor.from_pretrained(model_url)
37
+
38
+ elif model_name == "ViT Base":
39
+ self.model = ViTModel.from_pretrained(model_url)
40
+ self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)
41
+
42
+ elif model_name == "DINO ViT-S16":
43
+ self.model = ViTModel.from_pretrained(model_url)
44
+ self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)
45
+
46
+ def get_image_embedding(self, image):
47
+
48
+ # Process image according to the model
49
+ if hasattr(self, 'processor'): # CLIP models
50
+ inputs = self.processor(images=image, return_tensors="pt")
51
+ outputs = self.model.get_image_features(**inputs)
52
+
53
+ elif hasattr(self, 'feature_extractor'): # ViT models
54
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
55
+ outputs = self.model(**inputs).last_hidden_state
56
+
57
+ return outputs
58
+
59
+ def similarity_score(self, image1, image2):
60
+ embedding1 = self.get_image_embedding(image1).reshape(1, -1)
61
+ embedding2 = self.get_image_embedding(image2).reshape(1, -1)
62
+ similarity = cosine_similarity(embedding1.detach().numpy(), embedding2.detach().numpy())
 
63
  return similarity[0][0]