lpw commited on
Commit
da38675
1 Parent(s): 4b64e87

Create audio_pipe.py

Browse files
Files changed (1) hide show
  1. audio_pipe.py +161 -0
audio_pipe.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import List, Tuple
5
+ import tempfile
6
+ import soundfile as sf
7
+ import gradio as gr
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torchaudio
12
+ # from app.pipelines import Pipeline
13
+ from fairseq import hub_utils
14
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
15
+ from fairseq.models.speech_to_speech.hub_interface import S2SHubInterface
16
+ from fairseq.models.speech_to_text.hub_interface import S2THubInterface
17
+ from fairseq.models.text_to_speech import CodeHiFiGANVocoder
18
+ from fairseq.models.text_to_speech.hub_interface import (
19
+ TTSHubInterface,
20
+ VocoderHubInterface,
21
+ )
22
+ from huggingface_hub import snapshot_download
23
+
24
+ ARG_OVERRIDES_MAP = {
25
+ "facebook/xm_transformer_s2ut_800m-es-en-st-asr-bt_h1_2022": {
26
+ "config_yaml": "config.yaml",
27
+ "task": "speech_to_text",
28
+ }
29
+ }
30
+
31
+ class SpeechToSpeechPipeline():
32
+ def __init__(self, model_id: str):
33
+ arg_overrides = ARG_OVERRIDES_MAP.get(
34
+ model_id, {}
35
+ ) # Model specific override. TODO: Update on checkpoint side in the future
36
+ arg_overrides["config_yaml"] = "config.yaml" # common override
37
+ models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
38
+ model_id,
39
+ arg_overrides=arg_overrides,
40
+ cache_dir=os.getenv("HUGGINGFACE_HUB_CACHE"),
41
+ )
42
+ self.cfg = cfg
43
+ self.model = models[0].cpu()
44
+ self.model.eval()
45
+ self.task = task
46
+
47
+ self.sampling_rate = getattr(self.task, "sr", None) or 16_000
48
+
49
+ tgt_lang = self.task.data_cfg.hub.get("tgt_lang", None)
50
+ pfx = f"{tgt_lang}_" if self.task.data_cfg.prepend_tgt_lang_tag else ""
51
+
52
+ generation_args = self.task.data_cfg.hub.get(f"{pfx}generation_args", None)
53
+ if generation_args is not None:
54
+ for key in generation_args:
55
+ setattr(cfg.generation, key, generation_args[key])
56
+ self.generator = task.build_generator([self.model], cfg.generation)
57
+
58
+ tts_model_id = self.task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
59
+ self.unit_vocoder = self.task.data_cfg.hub.get(f"{pfx}unit_vocoder", None)
60
+ self.tts_model, self.tts_task, self.tts_generator = None, None, None
61
+ if tts_model_id is not None:
62
+ _id = tts_model_id.split(":")[-1]
63
+ cache_dir = os.getenv("HUGGINGFACE_HUB_CACHE")
64
+ if self.unit_vocoder is not None:
65
+ library_name = "fairseq"
66
+ cache_dir = (
67
+ cache_dir or (Path.home() / ".cache" / library_name).as_posix()
68
+ )
69
+ cache_dir = snapshot_download(
70
+ f"facebook/{_id}", cache_dir=cache_dir, library_name=library_name
71
+ )
72
+
73
+ x = hub_utils.from_pretrained(
74
+ cache_dir,
75
+ "model.pt",
76
+ ".",
77
+ archive_map=CodeHiFiGANVocoder.hub_models(),
78
+ config_yaml="config.json",
79
+ fp16=False,
80
+ is_vocoder=True,
81
+ )
82
+
83
+ with open(f"{x['args']['data']}/config.json") as f:
84
+ vocoder_cfg = json.load(f)
85
+ assert (
86
+ len(x["args"]["model_path"]) == 1
87
+ ), "Too many vocoder models in the input"
88
+
89
+ vocoder = CodeHiFiGANVocoder(x["args"]["model_path"][0], vocoder_cfg)
90
+ self.tts_model = VocoderHubInterface(vocoder_cfg, vocoder)
91
+
92
+ else:
93
+ (
94
+ tts_models,
95
+ tts_cfg,
96
+ self.tts_task,
97
+ ) = load_model_ensemble_and_task_from_hf_hub(
98
+ f"facebook/{_id}",
99
+ arg_overrides={"vocoder": "griffin_lim", "fp16": False},
100
+ cache_dir=cache_dir,
101
+ )
102
+ self.tts_model = tts_models[0].cpu()
103
+ self.tts_model.eval()
104
+ tts_cfg["task"].cpu = True
105
+ TTSHubInterface.update_cfg_with_data_cfg(
106
+ tts_cfg, self.tts_task.data_cfg
107
+ )
108
+ self.tts_generator = self.tts_task.build_generator(
109
+ [self.tts_model], tts_cfg
110
+ )
111
+
112
+ def __call__(self, inputs: str) -> Tuple[np.array, int, List[str]]:
113
+ """
114
+ Args:
115
+ inputs (:obj:`np.array`):
116
+ The raw waveform of audio received. By default sampled at `self.sampling_rate`.
117
+ The shape of this array is `T`, where `T` is the time axis
118
+ Return:
119
+ A :obj:`tuple` containing:
120
+ - :obj:`np.array`:
121
+ The return shape of the array must be `C'`x`T'`
122
+ - a :obj:`int`: the sampling rate as an int in Hz.
123
+ - a :obj:`List[str]`: the annotation for each out channel.
124
+ This can be the name of the instruments for audio source separation
125
+ or some annotation for speech enhancement. The length must be `C'`.
126
+ """
127
+ # _inputs = torch.from_numpy(inputs).unsqueeze(0)
128
+ # print(f"input: {inputs}")
129
+ # _inputs = torchaudio.load(inputs)
130
+ _inputs = inputs
131
+ sample, text = None, None
132
+ if self.cfg.task._name in ["speech_to_text", "speech_to_text_sharded"]:
133
+ sample = S2THubInterface.get_model_input(self.task, _inputs)
134
+ text = S2THubInterface.get_prediction(
135
+ self.task, self.model, self.generator, sample
136
+ )
137
+ elif self.cfg.task._name in ["speech_to_speech"]:
138
+ s2shubinerface = S2SHubInterface(self.cfg, self.task, self.model)
139
+ sample = s2shubinerface.get_model_input(self.task, _inputs)
140
+ text = S2SHubInterface.get_prediction(
141
+ self.task, self.model, self.generator, sample
142
+ )
143
+
144
+ wav, sr = np.zeros((0,)), self.sampling_rate
145
+ if self.unit_vocoder is not None:
146
+ tts_sample = self.tts_model.get_model_input(text)
147
+ wav, sr = self.tts_model.get_prediction(tts_sample)
148
+ text = ""
149
+ else:
150
+ tts_sample = TTSHubInterface.get_model_input(self.tts_task, text)
151
+ wav, sr = TTSHubInterface.get_prediction(
152
+ self.tts_task, self.tts_model, self.tts_generator, tts_sample
153
+ )
154
+ temp_file = ""
155
+ with tempfile.NamedTemporaryFile(suffix=".wav") as tmp_output_file:
156
+ sf.write(tmp_output_file, wav.detach().cpu().numpy(), sr)
157
+ tmp_output_file.seek(0)
158
+ temp_file = gr.Audio(tmp_output_file.name)
159
+
160
+ # return wav, sr, [text]
161
+ return temp_file