import os

import torch
from PIL import Image
from torch.nn import functional as F  # used only by the commented-out cosine-similarity path below

from clip_transform import CLIPTransform

class Prototypes:
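    """Mean CLIP image prototypes for a person / no-person classifier.

    Prototypes are built once from reference images in the 'prototypes' folder
    and cached in 'prototypes.pt'.
    """
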
    def __init__(self):
        self._clip_transform = CLIPTransform()
        self._load_prototypes()

    def _prepare_prototypes(self):
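        """Build and cache one normalized mean prototype embedding per class."""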
        image_embeddings = self.load_images_from_folder('prototypes')
        assert image_embeddings is not None, "no image embeddings found"
        assert len(image_embeddings) > 0, "no supported images found in the prototypes folder"
        person_keys = [key for key in image_embeddings.keys() if key.startswith('person-')]
        no_person_keys = [key for key in image_embeddings.keys() if key.startswith('no_person-')]
        person_keys.sort()
        no_person_keys.sort()
        # concatenate the per-image embeddings (assumed shape (1, D)) into an (N, D) matrix per class
        person_embeddings = torch.cat([image_embeddings[key] for key in person_keys])
        no_person_embeddings = torch.cat([image_embeddings[key] for key in no_person_keys])
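        # average each class, then re-normalize so dot products are cosine similarities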
        person_embedding = person_embeddings.mean(dim=0)
        person_embedding /= person_embedding.norm(dim=-1, keepdim=True)
        no_person_embedding = no_person_embeddings.mean(dim=0)
        no_person_embedding /= no_person_embedding.norm(dim=-1, keepdim=True)

        self.prototype_keys = ["person", "no_person"]
        self.prototypes = torch.stack([person_embedding, no_person_embedding])
        # save prototypes to file
        torch.save(self.prototypes, 'prototypes.pt')

    def _load_prototypes(self):
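        """Load cached prototypes, building them first if the cache is missing."""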
        # build the prototype cache on first run
        if not os.path.exists('prototypes.pt'):
            self._prepare_prototypes()
        self.prototypes = torch.load('prototypes.pt')
        self.prototype_keys = ["person", "no_person"]

    def load_images_from_folder(self, folder):
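        """Embed every supported image in `folder`; returns a {filename: embedding} dict."""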
        image_embeddings = {}
        supported_filetypes = ('.jpg', '.png', '.jpeg')
        for filename in os.listdir(folder):
            # case-insensitive extension check so e.g. 'PHOTO.JPG' is not skipped
            if not filename.lower().endswith(supported_filetypes):
                continue
            image = Image.open(os.path.join(folder,filename))
            embeddings = self._clip_transform.pil_image_to_embeddings(image)
            image_embeddings[filename] = embeddings
        return image_embeddings
    
    def get_distances(self, embeddings):
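        """Return (similarities, closest prototype key, debug string) for a 1-D normalized embedding.

        Despite the name, `distances` are cosine similarities: higher means closer.
        """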
        # if the inputs were not normalized, cosine similarity would be needed instead:
        # distances = F.cosine_similarity(embeddings, self.prototypes)
        # since both sides are normalized, the dot product is the cosine similarity
        distances = embeddings @ self.prototypes.T
        closest_item_index = distances.argmax().item()
        closest_item_key = self.prototype_keys[closest_item_index]
        debug_str = ", ".join(f"{key}: {value.item():.2f}" for key, value in zip(self.prototype_keys, distances))
        return distances, closest_item_key, debug_str
    

if __name__ == "__main__":
    prototypes = Prototypes()
    print("prototypes:")
    for key, value in zip(prototypes.prototype_keys, prototypes.prototypes):
        print(f"{key}: {len(value)}")  # len(value) is the embedding dimension

    embeddings = prototypes.prototypes[0]
    distances, closest_item_key, debug_str = prototypes.get_distances(embeddings)
    print(f"closest_item_key: {closest_item_key}")
    print(f"distances: {debug_str}")
    print("done")
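
    # Sketch of classifying a new image (assumes a file 'test.jpg' exists and that
    # CLIPTransform.pil_image_to_embeddings returns a normalized (1, D) tensor,
    # matching the assumption made in _prepare_prototypes above):
    # image = Image.open('test.jpg')
    # embedding = prototypes._clip_transform.pil_image_to_embeddings(image)
    # _, label, _ = prototypes.get_distances(embedding.squeeze(0))
    # print(f"label: {label}")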