Spaces: gpt-omni

gpt-omni committed
Commit 411819d
1 Parent(s): 9b186d7
update

app.py CHANGED
@@ -30,7 +30,7 @@ import soundfile as sf
 from litgpt.model import GPT, Config
 from lightning.fabric.utilities.load import _lazy_load as lazy_load
 from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
-from utils.snac_utils import get_snac
+from utils.snac_utils import get_snac
 import whisper
 from tqdm import tqdm
 from huggingface_hub import snapshot_download
@@ -80,19 +80,19 @@ if not os.path.exists(ckpt_dir):
 snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
 whispermodel = whisper.load_model("small").to(device)
 text_tokenizer = Tokenizer(ckpt_dir)
-fabric = L.Fabric(devices=1, strategy="auto")
+# fabric = L.Fabric(devices=1, strategy="auto")
 config = Config.from_file(ckpt_dir + "/model_config.yaml")
 config.post_adapter = False
 
 model = GPT(config, device=device)
 
-# model = fabric.setup(model)
 state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
 model.load_state_dict(state_dict, strict=True)
 model = model.to(device)
 model.eval()
 
 
+@spaces.GPU
 def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
     with torch.no_grad():
         mel = mel.unsqueeze(0).to(device)
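
This hunk trades the Lightning Fabric setup for ZeroGPU-style execution: module-level code such as the model loading above runs on CPU at startup, and a GPU is attached only while a function decorated with @spaces.GPU executes. A minimal sketch of the pattern, assuming the `spaces` package that Hugging Face provides on ZeroGPU Spaces (the model and function names here are illustrative, not from the commit):

    import spaces
    import torch

    model = torch.nn.Linear(4, 4)  # illustrative; loaded on CPU at startup, outside any GPU scope

    @spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
    def infer(x: torch.Tensor) -> torch.Tensor:
        with torch.inference_mode():
            return model.to("cuda")(x.to("cuda"))

A plausible reason the decorator moves off run_AT_batch_stream (later in this diff) and onto the helper functions is that run_AT_batch_stream appears to be a streaming generator, which the decorator handles less gracefully than plain function calls.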
@@ -128,6 +128,7 @@ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
     return torch.stack([audio_feature, audio_feature]), stacked_inputids
 
 
+@spaces.GPU
 def next_token_batch(
     model: GPT,
     audio_features: torch.tensor,
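
A small typing nit visible in this hunk: the annotation audio_features: torch.tensor points at the torch.tensor factory function rather than the torch.Tensor class. Python accepts it at runtime, but type checkers will not treat it as a tensor type. The conventional form (signature is illustrative):

    import torch

    # torch.Tensor is the class; torch.tensor() is a constructor function.
    def next_token_batch_typed(audio_features: torch.Tensor) -> torch.Tensor:
        return audio_features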
@@ -162,9 +163,19 @@ def load_audio(path):
     mel = whisper.log_mel_spectrogram(audio)
     return mel, int(duration_ms / 20) + 1
 
+
+@spaces.GPU
+def generate_audio_data(snac_tokens, snacmodel, device=None):
+    audio = reconstruct_tensors(snac_tokens, device)
+    with torch.inference_mode():
+        audio_hat = snacmodel.decode(audio)
+    audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
+    audio_data = audio_data.astype(np.int16)
+    audio_data = audio_data.tobytes()
+    return audio_data
+
 
 # @torch.inference_mode()
-@spaces.GPU
 def run_AT_batch_stream(
     audio_path,
     stream_stride=4,
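
The new generate_audio_data decodes the SNAC tokens back to a waveform and serializes it as 16-bit PCM bytes. The conversion step in isolation, as a standalone sketch (the function name here is illustrative):

    import numpy as np

    def float_to_pcm16(wave: np.ndarray) -> bytes:
        # Map floats in [-1.0, 1.0] to the int16 range, then to raw bytes.
        return (wave.astype(np.float64) * 32768.0).astype(np.int16).tobytes()

    # One second of 24 kHz audio -> 48000 bytes (2 bytes per sample).
    assert len(float_to_pcm16(np.zeros(24000, dtype=np.float32))) == 48000

One caveat worth noting: a full-scale sample of exactly +1.0 scales to 32768, one past the int16 maximum of 32767, so clipping (or a 32767.0 scale factor) is the safer variant. Separately, the load_audio context above sizes the Whisper input at one position per 20 ms plus one, so a 2,000 ms clip yields int(2000 / 20) + 1 = 101 positions.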
@@ -178,11 +189,10 @@ def run_AT_batch_stream(
 
     assert os.path.exists(audio_path), f"audio file {audio_path} not found"
 
-    # with self.fabric.init_tensor():
     model.set_kv_cache(batch_size=2)
 
     mel, leng = load_audio(audio_path)
-    audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng,
+    audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
     T = input_ids[0].size(1)
     device = input_ids[0].device
 
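
This last hunk removes the leftover `# with self.fabric.init_tensor():` comment and rewrites the get_input_ids_whisper_ATBatch call to pass the module-level whispermodel and device. The set_kv_cache(batch_size=2) line matches get_input_ids_whisper_ATBatch returning the audio feature stacked twice (earlier hunk), presumably so the model decodes two parallel token streams, audio and text, as a batch of two. A sketch of that shape contract (shapes are illustrative):

    import torch

    audio_feature = torch.randn(101, 512)  # hypothetical (T, C) Whisper feature
    stacked = torch.stack([audio_feature, audio_feature])
    assert stacked.shape == (2, 101, 512)  # batch of two parallel decode streams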