hahunavth committed on
Commit 363c289
Parent: 6721c76

Create handler.py

Files changed (1)
handler.py +116 -0
handler.py ADDED
@@ -0,0 +1,116 @@
+ from typing import Any, Dict, List
+
+ import torch
+ import speechbrain as sb
+ from hyperpyyaml import load_hyperpyyaml  # was used below without being imported
+
+
+ class PipelineSLUTask(sb.pretrained.interfaces.Pretrained):
+     """Pretrained wrapper for a direct SLU pipeline: ASR encoder + SLU decoder."""
+
+     HPARAMS_NEEDED = [
+         "slu_enc",
+         "output_emb",
+         "dec",
+         "seq_lin",
+         "env_corrupt",
+     ]
+     MODULES_NEEDED = [
+         "slu_enc",
+         "output_emb",
+         "dec",
+         "seq_lin",
+         "env_corrupt",
+     ]
+
+     def encode_file(self, path):
+         """Run the ASR encoder and the SLU decoder on a single audio file."""
+         tokens_bos = torch.tensor([[0]]).to(self.device)
+
+         waveform = self.load_audio(path)
+         # Fake a batch of size one:
+         wavs = waveform.unsqueeze(0)
+         wav_lens = torch.tensor([1.0])
+
+         with torch.no_grad():
+             # ASR encoder forward pass
+             asr_encoder_out = self.hparams.asr_model.encode_batch(
+                 wavs.detach(), wav_lens
+             )
+
+             # SLU forward pass
+             encoder_out = self.hparams.slu_enc(asr_encoder_out)
+             e_in = self.hparams.output_emb(tokens_bos)
+             h, _ = self.hparams.dec(e_in, encoder_out, wav_lens)
+
+             # Output layer for seq2seq log-probabilities
+             logits = self.hparams.seq_lin(h)
+             p_seq = self.hparams.log_softmax(logits)
+
+             # Beam search over the encoder output
+             p_tokens, _ = self.hparams.beam_searcher(encoder_out, wav_lens)
+
+         return p_seq, wav_lens, p_tokens
+
+     def decode(self, p_seq, wav_lens, predicted_tokens):
+         """Decode beam-search token ids into semantic labels."""
+         # Decode token ids to words. The original referenced a bare,
+         # undefined `tokenizer`; this assumes the tokenizer is available
+         # through the loaded hyperparameters.
+         predicted_semantics = [
+             self.hparams.tokenizer.decode_ids(utt_seq).split(" ")
+             for utt_seq in predicted_tokens
+         ]
+         return predicted_semantics
+
+
+ # PipelineSLUTask is defined above, so re-importing it from `pretrained`
+ # (which shadowed the local class) is not needed.
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Load the SpeechBrain hyperparameters for the direct SLU recipe.
+         hparams_file = f"{path}/direct-train.yaml"
+         overrides = {}
+         with open(hparams_file) as fin:
+             hparams = load_hyperpyyaml(fin, overrides)
+
+         run_opts = {
+             "device": "cuda" if torch.cuda.is_available() else "cpu"
+         }
+
+         self.pipeline = PipelineSLUTask(
+             modules=hparams["modules"],
+             hparams=hparams,
+             run_opts=run_opts,
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> List[List[str]]:
+         """
+         data args:
+             inputs (:obj:`str`): path of the audio file to decode
+         Return:
+             A :obj:`list`: the predicted semantics, which will be
+             serialized and returned
+         """
+         data = data.get("inputs", data)
+         p_seq, wav_lens, p_tokens = self.pipeline.encode_file(data)
+         return self.pipeline.decode(p_seq, wav_lens, p_tokens)
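
For reference, a minimal local smoke test of the handler. This is a sketch, not part of the commit: it assumes the working directory contains direct-train.yaml and the checkpoints it references, and "example.wav" is a hypothetical audio path; an Inference Endpoint invokes EndpointHandler the same way with the deserialized request payload.

# Hypothetical local test; assumes hparams/checkpoints sit in the current
# directory and "example.wav" is a placeholder audio file.
from handler import EndpointHandler

handler = EndpointHandler(path=".")
semantics = handler({"inputs": "example.wav"})
print(semantics)  # list of decoded semantic token sequences, one per utterance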