import os.path
import torch
from Preprocessing.multilinguality.create_distance_lookups import CacheCreator
from Utility.utils import load_json_from_path


class LanguageEmbeddingSpaceStructureLoss(torch.nn.Module):

    def __init__(self):
        super().__init__()
        cc = CacheCreator(cache_root="Preprocessing/multilinguality")
        # Build the distance caches on first use; each check guards its own file.
        if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json'):
            cc.create_tree_cache(cache_root="Preprocessing/multilinguality")
        if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_map_dist.json'):
            cc.create_map_cache(cache_root="Preprocessing/multilinguality")
        self.tree_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json')
        self.map_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_map_dist.json')
        # with open("Preprocessing/multilinguality/asp_dict.pkl", 'rb') as dictfile:
        #     self.asp_sim = pickle.load(dictfile)
        # self.lang_list = list(self.asp_sim.keys())  # list of all languages, to get lang_b's index
        # Find the largest map distance so that map distances can be normalized to [0, 1].
        self.largest_value_map_dist = 0.0
        for _, values in self.map_dist.items():
            for _, value in values.items():
                self.largest_value_map_dist = max(self.largest_value_map_dist, value)
        self.iso_codes_to_ids = load_json_from_path("Preprocessing/multilinguality/iso_lookup.json")[-1]
        self.ids_to_iso_codes = {v: k for k, v in self.iso_codes_to_ids.items()}
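
    # Expected shape of the lookup tables (an assumption inferred from the access
    # pattern in forward, not documented in the original file):
    #   self.tree_dist == {"<iso_1>": {"<iso_2>": tree_distance, ...}, ...}
    #   self.map_dist  == {"<iso_1>": {"<iso_2>": map_distance, ...}, ...}
    # Each pair may be stored in only one direction, hence the try/except
    # fallbacks in forward that retry the lookup with the keys swapped.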

    def forward(self, language_ids, language_embeddings):
        """
        Args:
            language_ids (Tensor): IDs of the languages, in the same order as the
                embeddings, used to look up the pairwise distances according to the metrics.
            language_embeddings (Tensor): Batch of language embeddings whose pairwise
                distances are compared to the distances according to the metrics.

        Returns:
            Tensor: Language embedding structure loss value
        """
        losses = list()
        for language_id_1, language_embedding_1 in zip(language_ids, language_embeddings):
            for language_id_2, language_embedding_2 in zip(language_ids, language_embeddings):
                if language_id_1 != language_id_2:
                    embed_dist = torch.nn.functional.l1_loss(language_embedding_1, language_embedding_2)
                    # Tensor scalars are not usable as dict keys here, so cast the IDs to int first.
                    lang_1 = self.ids_to_iso_codes[int(language_id_1)]
                    lang_2 = self.ids_to_iso_codes[int(language_id_2)]

                    # Value range normalized tree distance
                    try:
                        tree_dist = self.tree_dist[lang_1][lang_2]
                    except KeyError:
                        tree_dist = self.tree_dist[lang_2][lang_1]

                    # Value range normalized map distance
                    try:
                        map_dist = self.map_dist[lang_1][lang_2] / self.largest_value_map_dist
                    except KeyError:
                        map_dist = self.map_dist[lang_2][lang_1] / self.largest_value_map_dist

                    # Value range normalized ASP distance (currently disabled)
                    # lang_2_idx = self.lang_list.index(lang_2)
                    # asp_dist = 1.0 - self.asp_sim[lang_1][lang_2_idx]  # it's a similarity measure that goes from 0 to 1, so we subtract it from 1 to turn it into a distance

                    # The average metric distance should be similar to the embedding distance,
                    # to bring some structure into the embedding space.
                    # metric_distance = (torch.tensor(tree_dist) + torch.tensor(map_dist) + torch.tensor(asp_dist)) / 3
                    metric_distance = (torch.tensor(tree_dist) + torch.tensor(map_dist)) / 2

                    losses.append(torch.nn.functional.l1_loss(embed_dist, metric_distance))
        return sum(losses) / len(losses)
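

# --- Minimal usage sketch (not part of the original file; IDs, shapes, and the
# embedding dimension are illustrative assumptions). Requires the lookup JSONs
# under Preprocessing/multilinguality to exist or be creatable by CacheCreator,
# and language IDs that actually occur as values in iso_lookup.json.
if __name__ == '__main__':
    less_loss = LanguageEmbeddingSpaceStructureLoss()
    ids = torch.tensor([3, 7, 12])                       # three hypothetical language IDs
    embeddings = torch.randn(3, 16, requires_grad=True)  # three 16-dim language embeddings
    loss = less_loss(ids, embeddings)
    loss.backward()                                      # gradients flow into the embeddings
    print(loss.item())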