File size: 3,593 Bytes
58942ad
363c289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58942ad
363c289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

import torch
import speechbrain as sb

# Reporting interval: plot/print results every N training iterations.
show_results_every = 100

# Run on GPU when one is visible, otherwise fall back to CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
run_opts = {"device": _device}

class PipelineSLUTask(sb.pretrained.interfaces.Pretrained):
    """Pretrained-interface wrapper for a pipeline SLU (spoken language
    understanding) model: an upstream ASR encoder feeds an SLU encoder,
    an attentional seq2seq decoder, and a beam searcher.

    Required hparams/modules are declared below; ``asr_model``,
    ``log_softmax``, ``beam_searcher`` and ``tokenizer`` are also read
    from ``self.hparams`` at inference time.
    """

    HPARAMS_NEEDED = [
        "slu_enc",
        "output_emb",
        "dec",
        "seq_lin",
        "env_corrupt",
    ]
    MODULES_NEEDED = [
        "slu_enc",
        "output_emb",
        "dec",
        "seq_lin",
        "env_corrupt",
    ]

    def encode_file(self, path):
        """Run the full SLU forward pass on a single audio file.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe/understand.

        Returns
        -------
        tuple
            ``(p_seq, wav_lens, predicted_tokens)`` — seq2seq
            log-probabilities, relative wav lengths, and beam-search
            token-id hypotheses.
        """
        # Single BOS token (id 0) seeding the autoregressive decoder.
        tokens_bos = torch.tensor([[0]]).to(self.device)

        waveform = self.load_audio(path)
        # Fake a batch of one; relative length 1.0 == full utterance.
        wavs = waveform.unsqueeze(0)
        wav_lens = torch.tensor([1.0])

        with torch.no_grad():
            # Upstream ASR encoder forward pass.
            ASR_encoder_out = self.hparams.asr_model.encode_batch(
                wavs.detach(), wav_lens
            )

            # SLU encoder + decoder forward pass.
            encoder_out = self.hparams.slu_enc(ASR_encoder_out)
            e_in = self.hparams.output_emb(tokens_bos)
            h, _ = self.hparams.dec(e_in, encoder_out, wav_lens)

            # Output layer for seq2seq log-probabilities.
            logits = self.hparams.seq_lin(h)
            p_seq = self.hparams.log_softmax(logits)

            # Beam search over encoder states for token hypotheses.
            p_tokens, _scores = self.hparams.beam_searcher(
                encoder_out, wav_lens
            )
            return p_seq, wav_lens, p_tokens

    def decode(self, p_seq, wav_lens, predicted_tokens):
        """Decode beam-search token ids into semantic word sequences.

        ``p_seq`` and ``wav_lens`` are accepted for interface symmetry
        with :meth:`encode_file` but are not used here.

        Returns
        -------
        list[list[str]]
            One space-split decoded string per utterance hypothesis.
        """
        # BUGFIX: the original referenced a bare global ``tokenizer``
        # that is never defined in this file (NameError at runtime).
        # The SentencePiece tokenizer lives in hparams per SpeechBrain
        # convention — NOTE(review): confirm the exact hparams key.
        tokenizer = self.hparams.tokenizer
        return [
            tokenizer.decode_ids(utt_seq).split(" ")
            for utt_seq in predicted_tokens
        ]


from typing import Dict, List, Any
from pretrained import PipelineSLUTask

class EndpointHandler:
    """Inference-endpoint handler: loads SLU hyperparameters from YAML
    and serves predictions through a :class:`PipelineSLUTask` pipeline.
    """

    def __init__(self, path=""):
        """Build the pipeline from ``{path}/better_tokenizer/1986/direct-train.yaml``.

        Arguments
        ---------
        path : str
            Root directory of the deployed model repository.
        """
        # BUGFIX: ``load_hyperpyyaml`` was called without ever being
        # imported in this file (NameError on first request). Import
        # locally so the module stays importable without hyperpyyaml.
        from hyperpyyaml import load_hyperpyyaml

        hparams_file = f"{path}/better_tokenizer/1986/direct-train.yaml"
        overrides = {}
        with open(hparams_file) as fin:
            hparams = load_hyperpyyaml(fin, overrides)

        # Prefer GPU when available, mirroring the module-level default.
        run_opts = {
            "device": "cuda" if torch.cuda.is_available() else "cpu"
        }

        self.pipeline = PipelineSLUTask(
            modules=hparams["modules"],
            hparams=hparams,
            run_opts=run_opts,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run SLU inference on one request.

        data args:
            inputs (:obj:`str`): path to the audio file to process;
            a bare payload (no ``"inputs"`` key) is used as-is.
        Return:
            A :obj:`list` of decoded semantic sequences (serialized
            and returned by the serving framework).
        """
        payload = data.get("inputs", data)
        print(payload)  # request-level trace for endpoint logs
        p_seq, wav_lens, p_tokens = self.pipeline.encode_file(payload)
        return self.pipeline.decode(p_seq, wav_lens, p_tokens)