mrfakename committed on
Commit 06cc563
1 Parent(s): 92e1714

Create cog.py

Files changed (1)
  1. cog.py +181 -0
cog.py ADDED
@@ -0,0 +1,181 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

from cog import BasePredictor, Input, Path

import os
import re
import torch
import torchaudio
import gradio as gr
import numpy as np
import tempfile
from einops import rearrange
from ema_pytorch import EMA
from vocos import Vocos
from pydub import AudioSegment
from model import CFM, UNetT, DiT, MMDiT
from cached_path import cached_path
from model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
    save_spectrogram,
)
from transformers import pipeline
import librosa

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = 'euler'
sway_sampling_coef = -1.0
speed = 1.0
# fix_duration = 27  # None or float (duration in seconds)
fix_duration = None

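# nfe_step sets the number of ODE solver steps, cfg_strength the classifier-free guidance
# scale, and sway_sampling_coef the sway sampling coefficient used by F5-TTS.
# At target_sample_rate = 24000 and hop_length = 256, one mel frame covers 256 / 24000
# ≈ 10.7 ms (~93.75 frames per second); the duration passed to the sampler is in frames.
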
class Predictor(BasePredictor):
    def load_model(self, exp_name, model_cls, model_cfg, ckpt_step):
        # Build the CFM model and load the EMA weights from the F5-TTS checkpoint on the Hub
        checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
        vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
        model = CFM(
            transformer=model_cls(
                **model_cfg,
                text_num_embeds=vocab_size,
                mel_dim=n_mel_channels
            ),
            mel_spec_kwargs=dict(
                target_sample_rate=target_sample_rate,
                n_mel_channels=n_mel_channels,
                hop_length=hop_length,
            ),
            odeint_kwargs=dict(
                method=ode_method,
            ),
            vocab_char_map=vocab_char_map,
        ).to(device)

        ema_model = EMA(model, include_online_model=False).to(device)
        ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        ema_model.copy_params_from_ema_to_model()

        return ema_model, model

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # self.model = torch.load("./weights.pth")
        print("Loading Whisper model...")
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=torch.float16,
            device=device,
        )
        print("Loading F5-TTS model...")

        F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
        self.F5TTS_ema_model, self.F5TTS_base_model = self.load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)

    def predict(
        self,
        gen_text: str = Input(description="Text to generate"),
        ref_audio_orig: Path = Input(description="Reference audio"),
        # ref_text is referenced below; assumed optional input, empty string means "auto-transcribe"
        ref_text: str = Input(description="Reference text (leave empty to transcribe the reference audio)", default=""),
        remove_silence: bool = Input(description="Remove silences", default=True),
    ) -> Path:
        """Run a single prediction on the model"""
        model_choice = "F5-TTS"
        print(gen_text)
        if len(gen_text) > 200:
            raise gr.Error("Please keep your text under 200 chars.")
        gr.Info("Converting audio...")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            aseg = AudioSegment.from_file(ref_audio_orig)
            audio_duration = len(aseg)  # pydub reports length in milliseconds
            if audio_duration > 15000:
                gr.Warning("Audio is over 15s, clipping to only first 15s.")
                aseg = aseg[:15000]
            aseg.export(f.name, format="wav")
            ref_audio = f.name
        ema_model = self.F5TTS_ema_model
        base_model = self.F5TTS_base_model

        if not ref_text.strip():
            gr.Info("No reference text provided, transcribing reference audio...")
            ref_text = self.pipe(
                ref_audio,
                chunk_length_s=30,
                batch_size=128,
                generate_kwargs={"task": "transcribe"},
                return_timestamps=False,
            )['text'].strip()
            gr.Info("Finished transcription")
        else:
            gr.Info("Using custom reference text...")
        audio, sr = torchaudio.load(ref_audio)

        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < target_rms:
            audio = audio * target_rms / rms
        if sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
            audio = resampler(audio)
        audio = audio.to(device)

        # Prepare the text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Calculate duration
        ref_audio_len = audio.shape[-1] // hop_length
        # if fix_duration is not None:
        #     duration = int(fix_duration * target_sample_rate / hop_length)
        # else:
        zh_pause_punc = r"。,、;:?!"
        ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
        gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

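        # Worked example of the heuristic above: a 10 s reference at 24 kHz gives
        # 240000 // 256 = 937 reference frames; with a 50-char reference text and a
        # 100-char generation text, duration = 937 + int(937 / 50 * 100) = 2811 frames,
        # i.e. roughly 30 s in total (~20 s of newly generated speech).
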
        # inference
        gr.Info("Generating audio using F5-TTS")
        with torch.inference_mode():
            generated, _ = base_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
        gr.Info("Running vocoder")
        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        # wav -> numpy
        generated_wave = generated_wave.squeeze().cpu().numpy()

        if remove_silence:
            gr.Info("Removing audio silences... This may take a moment")
            non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
            non_silent_wave = np.array([])
            for interval in non_silent_intervals:
                start, end = interval
                non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
            generated_wave = non_silent_wave

        # write the generated waveform to a temporary wav file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
            wav_path = tmp_wav.name
            # torchaudio.save expects a 2D (channels, samples) float tensor
            torchaudio.save(wav_path, torch.tensor(generated_wave, dtype=torch.float32).unsqueeze(0), target_sample_rate)

        return Path(wav_path)
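
In a Cog project this predictor is normally exercised through the CLI once cog.yaml points at it, for example cog predict -i gen_text="..." -i ref_audio_orig=@ref.wav. As a rough local smoke test it can also be driven directly from Python; the sketch below is only an illustration: the module name, reference clip, and texts are placeholders, an empty ref_text triggers automatic transcription of the reference, and it assumes the Whisper and F5-TTS checkpoints can be downloaded.

# Minimal local smoke test (sketch; module name and paths are placeholders)
from pathlib import Path

from predict import Predictor  # placeholder module name for the file above

predictor = Predictor()
predictor.setup()  # loads Whisper and the F5-TTS checkpoint onto the selected device

output_wav = predictor.predict(
    gen_text="Hello from F5-TTS running under Cog.",
    ref_audio_orig=Path("ref.wav"),  # short (<= 15 s) reference clip
    ref_text="",                     # empty string -> auto-transcribe the reference
    remove_silence=True,
)
print("Generated audio written to:", output_wav)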