File size: 3,619 Bytes
e775f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49cf77e
 
e775f6d
004e4fc
 
e775f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from src.music.utilities.representation_learning_utilities.constants import *
from src.music.config import REP_MODEL_NAME
from src.music.utils import  get_out_path
import pickle
import numpy as np
# from transformers import AutoModel, AutoTokenizer
from torch import nn
from src.music.representation_learning.sentence_transfo.sentence_transformers import SentenceTransformer

class Argument:
    """Lightweight namespace: every key of ``adict`` becomes an instance attribute."""

    def __init__(self, adict):
        # Write straight into the instance namespace (same as __dict__.update),
        # so any mapping or iterable of (name, value) pairs is accepted.
        namespace = vars(self)
        namespace.update(adict)

class RepModel(nn.Module):
    """Wrap a pretrained model and mean-pool its last hidden layer into one vector.

    When ``'t5'`` appears in ``model_name`` only ``model.get_encoder()`` is kept;
    otherwise ``model`` is used as-is. The wrapped model is switched to eval mode.

    NOTE(review): ``torch`` is not imported explicitly in the visible imports;
    presumably it comes from the star import of the constants module — confirm.
    """

    def __init__(self, model, model_name):
        super().__init__()
        # T5 checkpoints are encoder-decoder; only the encoder stack is wrapped.
        self.model = model.get_encoder() if 't5' in model_name else model
        self.model.eval()

    def forward(self, inputs):
        """Return the mean over dim 0 of ``hidden_states[-1][0]``.

        Assumes the last hidden layer is shaped (batch, seq, hidden), in which
        case this averages the sequence of the first batch item — TODO confirm.
        """
        with torch.no_grad():
            outputs = self.model(inputs, output_hidden_states=True)
        last_layer = outputs.hidden_states[-1]
        first_item = last_layer[0]
        return first_item.mean(dim=0)

# def get_trained_music_LM(model_name):
#     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
#     model = RepModel(AutoModel.from_pretrained(model_name, use_auth_token=True), model_name)
#
#     return model, tokenizer

def get_trained_sentence_embedder(model_name):
    """Build and return a ``SentenceTransformer`` loaded from ``model_name``."""
    return SentenceTransformer(model_name)


# Loaded eagerly at import time and shared by every call to encoded2rep.
MODEL = get_trained_sentence_embedder(REP_MODEL_NAME)

def encoded2rep(encoded_path, rep_path=None, return_rep=False, verbose=False, level=0):
    """Map an encoded performance file to a sentence-embedding representation.

    Loads the pickle at ``encoded_path`` and drops every token equal to 1 from
    ``data['main']``. It joins the remaining tokens into a space-separated string
    and embeds that string with the module-level ``MODEL``. The result is saved
    with ``np.savetxt``.

    Args:
        encoded_path: path to the pickled encoded performance.
        rep_path: output path. When falsy, it is derived from ``encoded_path``
            via ``get_out_path`` (``'encoded'`` -> ``'represented'``, ``.txt``).
        return_rep: if True, also return the computed representation.
        verbose: print progress and failure messages.
        level: indentation (number of spaces) for verbose messages.

    Returns:
        ``(rep_path, '')`` on success or ``(None, error_msg)`` on failure. With
        ``return_rep=True``: ``(rep_path, representation, '')`` or
        ``(None, None, error_msg)``.
    """
    if not rep_path:
        rep_path, _, _ = get_out_path(in_path=encoded_path, in_word='encoded', out_word='represented', out_extension='.txt')

    # error_msg accumulates breadcrumbs so the returned message shows how far processing got.
    error_msg = 'Error in music transformer mapping.'
    if verbose: print(' ' * level + 'Mapping to final music representations')
    try:
        error_msg += ' Error in encoded file loading?'
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(encoded_path, 'rb') as f:
            data = pickle.load(f)
        # Token 1 is filtered out; presumably a padding/special token — confirm.
        performance = [str(w) for w in data['main'] if w != 1]
        # Explicit raises rather than `assert`, which is stripped under `python -O`.
        if len(performance) % 5 != 0:
            raise ValueError('number of tokens is not a multiple of 5')
        if len(performance) == 0:
            error_msg += " Error: No midi messages in primer file"
            raise ValueError('no midi messages')
        error_msg += ' Nope, error in tokenization?'
        perf = ' '.join(performance)
        error_msg += ' Nope. Maybe in performance encoding?'
        representation = MODEL.encode(perf)
        error_msg += ' Nope. Saving performance?'
        np.savetxt(rep_path, representation)
        error_msg += ' Nope.'
    except Exception as err:
        # Best-effort boundary: report failure via the return value, but let
        # KeyboardInterrupt / SystemExit propagate (unlike a bare `except:`).
        if verbose: print(' ' * (level + 2) + f'Failed with error: {error_msg} ({type(err).__name__}: {err})')
        if return_rep:
            return None, None, error_msg
        return None, error_msg

    if verbose: print(' ' * (level + 2) + 'Success.')
    if return_rep:
        return rep_path, representation, ''
    return rep_path, ''

if __name__ == "__main__":
    import sys

    # Usage: python <this file> [path/to/encoded.pickle]
    # Falls back to the original developer's sample file when no path is given.
    default_path = "/home/cedric/Documents/pianocktail/data/music/encoded/single_videos_midi_processed_encoded/chris_dawson_all_of_me_.pickle"
    in_path = sys.argv[1] if len(sys.argv) > 1 else default_path
    representation = encoded2rep(in_path, verbose=True)
    print(representation)