File size: 3,859 Bytes
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70399da
 
 
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70399da
 
9e275b8
 
70399da
 
9e275b8
 
70399da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os.path

import torch

from Preprocessing.multilinguality.create_distance_lookups import CacheCreator
from Utility.utils import load_json_from_path


class LanguageEmbeddingSpaceStructureLoss(torch.nn.Module):
    """
    Loss that pulls the distances between language embeddings towards precomputed language distances.

    For every ordered pair of entries in a batch whose language IDs differ, the mean absolute
    difference between the two embeddings is compared (with an L1 loss) to the average of a tree
    distance and a map distance looked up from JSON caches under ``Preprocessing/multilinguality``.
    Missing tree/map caches are built on construction via ``CacheCreator``.
    """

    def __init__(self):
        super().__init__()
        cache_root = "Preprocessing/multilinguality"
        tree_dist_path = 'Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json'
        map_dist_path = 'Preprocessing/multilinguality/lang_1_to_lang_2_to_map_dist.json'

        cc = CacheCreator(cache_root=cache_root)
        if not os.path.exists(tree_dist_path):
            cc.create_tree_cache(cache_root=cache_root)
        # Check the map cache's own file. Re-checking the tree file here (as before) meant a
        # missing map cache was never built once the tree cache existed, and the load below failed.
        if not os.path.exists(map_dist_path):
            cc.create_map_cache(cache_root=cache_root)

        self.tree_dist = load_json_from_path(tree_dist_path)
        self.map_dist = load_json_from_path(map_dist_path)

        # Largest map distance (floored at 0.0); forward() divides map distances by it to bring them
        # into a comparable value range.
        self.largest_value_map_dist = 0.0
        for values in self.map_dist.values():
            for value in values.values():
                self.largest_value_map_dist = max(self.largest_value_map_dist, value)

        # The last entry of iso_lookup.json is used as the ISO-code -> language-ID mapping.
        self.iso_codes_to_ids = load_json_from_path("Preprocessing/multilinguality/iso_lookup.json")[-1]
        self.ids_to_iso_codes = {v: k for k, v in self.iso_codes_to_ids.items()}

    @staticmethod
    def _id_key(language_id):
        """Turn one language ID into a plain Python value usable as a key into ``ids_to_iso_codes``."""
        if isinstance(language_id, torch.Tensor):
            # Tensors hash by object identity, so a tensor element can never match the dict's keys.
            # NOTE(review): assumes the IDs in iso_lookup.json are numeric — confirm.
            return language_id.item()
        return language_id

    @staticmethod
    def _symmetric_lookup(table, lang_1, lang_2):
        """
        Return ``table[lang_1][lang_2]``, falling back to ``table[lang_2][lang_1]``.

        Presumably the caches store each unordered language pair in only one direction.

        Raises:
            KeyError: if the pair is missing in both directions.
        """
        try:
            return table[lang_1][lang_2]
        except KeyError:
            return table[lang_2][lang_1]

    def forward(self, language_ids, language_embeddings):
        """
        Args:
            language_ids (Tensor): IDs of languages in the same order as the embeddings to calculate the distances according to the metrics.
                Tensor elements or plain Python IDs are both accepted.
            language_embeddings (Tensor): Batch of language embeddings, of which the distances will be compared to the distances according to the metrics.
                It is iterated along its first dimension, in lockstep with ``language_ids``.

        Returns:
            Tensor: Language Embedding Structure Loss Value. This is a scalar: the mean, over all ordered
            pairs with differing IDs, of |embedding distance - metric distance|. It is zero when the
            batch contains no such pair (e.g. every entry has the same language).

        Raises:
            KeyError: if an ID is unknown, or a language pair is missing from a distance cache.
        """
        keys = [self._id_key(language_id) for language_id in language_ids]

        losses = []
        for id_1, embedding_1 in zip(keys, language_embeddings):
            for id_2, embedding_2 in zip(keys, language_embeddings):
                if id_1 == id_2:
                    continue
                # Mean absolute difference between the two embeddings.
                embed_dist = torch.nn.functional.l1_loss(embedding_1, embedding_2)
                lang_1 = self.ids_to_iso_codes[id_1]
                lang_2 = self.ids_to_iso_codes[id_2]

                # Tree distance is used as stored (presumably already in [0, 1] — TODO confirm).
                tree_dist = self._symmetric_lookup(self.tree_dist, lang_1, lang_2)
                # Map distance is scaled by the largest map distance seen in the cache.
                map_dist = self._symmetric_lookup(self.map_dist, lang_1, lang_2) / self.largest_value_map_dist

                # The average metric distance should be similar to the embedding distance, to bring
                # some structure into the embedding space. An ASP-similarity-based distance was once
                # averaged in as a third metric; it is currently disabled.
                # The target is built on the embedding's device/dtype to avoid a CPU tensor per pair.
                metric_distance = embed_dist.new_tensor((tree_dist + map_dist) / 2)
                losses.append(torch.nn.functional.l1_loss(embed_dist, metric_distance))

        if not losses:
            # No pair of distinct languages in this batch: nothing to constrain.
            return language_embeddings.new_zeros(())
        return torch.stack(losses).mean()