from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from fairseq2.assets.card import AssetCard
from fairseq2.data import Collater
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.data.text.text_tokenizer import TextTokenizer
from fairseq2.data.typing import StringLike
from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
from fairseq2.memory import MemoryBlock
from fairseq2.typing import DataType, Device
from torch import Tensor
from enum import Enum, auto
from seamless_communication.models.inference.ngram_repeat_block_processor import (
    NGramRepeatBlockProcessor,
)
from seamless_communication.models.unity import (
    UnitTokenizer,
    UnitYGenerator,
    UnitYModel,
    load_unity_model,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
from seamless_communication.models.unity.generator import SequenceToUnitOutput
from seamless_communication.models.vocoder import load_vocoder_model, Vocoder

# from seamless_communication.models.streaming.agents import (
#     SileroVADAgent,
#     TestTimeWaitKS2TVAD,
#     TestTimeWaitKUnityV1M4T
# )
from seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t import (
    TestTimeWaitKUnityS2TM4T,
)
from seamless_communication.cli.streaming.dataloader import (
    Fairseq2SpeechToTextDataloader,
)

### From test_pipeline
import math
import soundfile
from argparse import Namespace, ArgumentParser
from simuleval.data.segments import SpeechSegment, EmptySegment

# NOTE: `evaluate` was missing from the original imports but is called in
# load_model_fairseq2 below; it is assumed to be exported by simuleval.utils.
from simuleval.utils import build_system_from_dir, evaluate
import numpy as np


class AudioFrontEnd:
    """Chunk a wav file into fixed-size speech segments, mimicking the
    front-end logic of simuleval's instance.py."""

    def __init__(self, wav_file, segment_size) -> None:
        self.samples, self.sample_rate = soundfile.read(wav_file)
        self.samples = self.samples.tolist()
        self.segment_size = segment_size  # in milliseconds
        self.step = 0
    def send_segment(self):
        """
        This mirrors the front-end logic in simuleval's instance.py.
        """
        # segment_size is in milliseconds, so convert it to a sample count.
        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
        print(
            f"segment_size={self.segment_size} ms, "
            f"num_samples={num_samples}, sample_rate={self.sample_rate}"
        )
        if self.step < len(self.samples):
            if self.step + num_samples >= len(self.samples):
                samples = self.samples[self.step :]
                is_finished = True
            else:
                samples = self.samples[self.step : self.step + num_samples]
                is_finished = False
            self.step = min(self.step + num_samples, len(self.samples))
            segment = SpeechSegment(
                index=self.step / self.sample_rate * 1000,  # offset in ms
                content=samples,
                sample_rate=self.sample_rate,
                finished=is_finished,
            )
        else:
            # Finished reading this audio.
            segment = EmptySegment(
                index=self.step / self.sample_rate * 1000,
                finished=True,
            )
        return segment
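

# Illustrative sketch (not part of the original pipeline): drain an
# AudioFrontEnd until it reports a finished or empty segment. The wav path
# is whatever 16 kHz mono file you have on hand; the helper name is ours.
def _demo_audio_frontend(wav_file: str, segment_size: int = 320) -> int:
    """Return the number of speech segments produced for the file."""
    frontend = AudioFrontEnd(wav_file=wav_file, segment_size=segment_size)
    count = 0
    while True:
        segment = frontend.send_segment()
        if isinstance(segment, EmptySegment):
            break
        count += 1
        if segment.finished:
            break
    return count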


def load_model_for_inference(
    load_model_fn: Callable[..., nn.Module],
    model_name_or_card: Union[str, AssetCard],
    device: Device,
    dtype: DataType,
) -> nn.Module:
    model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
    model.eval()
    return model
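

# Sketch: the helper above works with any fairseq2-style load function. The
# card names mirror the ones used in SimulevalTranscoder below; the helper
# name itself is ours.
def _load_default_models(device: Device, dtype: DataType) -> Tuple[nn.Module, nn.Module]:
    model = load_model_for_inference(load_unity_model, "seamlessM4T_medium", device, dtype)
    vocoder = load_model_for_inference(load_vocoder_model, "vocoder_36langs", device, dtype)
    return model, vocoder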


def load_model_fairseq2():
    """Run a SimulEval evaluation of the wait-k S2T agent on the dev manifest."""
    data_configs = dict(
        dataloader="fairseq2_s2t",
        data_file="/large_experiments/seamless/ust/abinesh/data/s2st50_manifests/50-10/simuleval/dev_mtedx_filt_50-10_debug.tsv",
    )
    model_configs = dict(
        model_name="seamlessM4T_v2_large",
        device="cuda:0",
        source_segment_size=320,
        waitk_lagging=7,
        fixed_pre_decision_ratio=2,
        init_target_tokens="</s> __eng__",
        max_len_a=0,
        max_len_b=200,
        agent_class="seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t.TestTimeWaitKUnityS2TM4T",
        task="s2st",
        tgt_lang="eng",
    )
    eval_configs = dict(
        latency_metrics="StartOffset EndOffset AL",
        output=f"{TestTimeWaitKUnityS2TM4T.__name__}-wait{model_configs['waitk_lagging']}-debug",
    )
    model = TestTimeWaitKUnityS2TM4T({**data_configs, **model_configs, **eval_configs})
    print("model", model)
    evaluate(
        TestTimeWaitKUnityS2TM4T, {**data_configs, **model_configs, **eval_configs}
    )
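

# Sketch: build_system_from_dir (imported above) can restore a SimulEval
# system that was saved to disk, instead of configuring one in code as
# load_model_fairseq2 does. The directory path here is hypothetical.
def _load_saved_system(system_dir: str):
    return build_system_from_dir(Path(system_dir))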


class SimulevalTranscoder:
    def __init__(self):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        load_model_fairseq2()
        # Force CPU for now; in CPU mode float16 causes errors downstream,
        # so float32 is required.
        device = torch.device("cpu")
        print("DEVICE", device)
        model_name_or_card = "seamlessM4T_medium"
        vocoder_name_or_card = "vocoder_36langs"
        dtype = torch.float32
        model: UnitYModel = load_model_for_inference(
            load_unity_model, model_name_or_card, device, dtype
        )
        print(model, type(model))
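        # Note (assumption): vocoder_name_or_card is declared but never loaded
        # in this snippet; loading it would mirror the model load above, e.g.
        #   vocoder: Vocoder = load_model_for_inference(
        #       load_vocoder_model, vocoder_name_or_card, device, dtype
        #   )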
        source_segment_size = 320  # milliseconds
        audio_frontend = AudioFrontEnd(
            wav_file="/checkpoint/mduppes/samples/marta.wav",
            segment_size=source_segment_size,
        )
        # Mostly taken from the first S2S agent: OnlineFeatureExtractorAgent defaults.
        SHIFT_SIZE = 10
        WINDOW_SIZE = 25
        SAMPLE_RATE = 16000
        FEATURE_DIM = 80
        # Collect args and convert to a Namespace so fields can be accessed
        # via dot notation.
        args = {
            "shift_size": SHIFT_SIZE,
            "window_size": WINDOW_SIZE,
            "sample_rate": audio_frontend.sample_rate,
            "feature_dim": 160,  # from Wav2Vec2Frontend
            "denormalize": False,  # not sure
            "global_stats": None,  # default file path containing cmvn stats
        }
        print(args)
        args = Namespace(**args)
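        # With shift_size=10 ms at 16 kHz, each feature frame advances
        # 160 samples, so a 320 ms source segment yields about 32 frames.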
        # The TestTimeWaitKUnityV1M4T import above is commented out, so the
        # imported TestTimeWaitKUnityS2TM4T agent is substituted here (assumed
        # to accept the same (model, args) construction).
        pipeline = TestTimeWaitKUnityS2TM4T(model, args)
        system_states = pipeline.build_states()
        print("system states:")
        for state in system_states:
            print(state, vars(state))
        input_segment = np.empty(0, dtype=np.int16)
        segments = []
        while True:
            speech_segment = audio_frontend.send_segment()
            input_segment = np.concatenate(
                (input_segment, np.array(speech_segment.content))
            )
            # Translation happens here.
            output_segment = pipeline.pushpop(speech_segment, system_states)
            print("pushpop result")
            print(output_segment)
            print("system states after pushpop:")
            for state in system_states:
                print(state, vars(state))
            if output_segment.finished:
                segments.append(input_segment)
                input_segment = np.empty(0, dtype=np.int16)
                print("Resetting states")
                for state in system_states:
                    state.reset()
            if speech_segment.finished:
                break
        # Write out the VAD-segmented samples from the full input audio.
        for i, seg in enumerate(segments):
            with soundfile.SoundFile(
                Path("/checkpoint/mduppes/samples") / f"marta_{i}.wav",
                mode="w+",
                format="WAV",
                samplerate=16000,
                channels=1,
            ) as f:
                f.seek(0, soundfile.SEEK_END)
                f.write(seg)
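

# Usage sketch: constructing the transcoder runs the whole debug pipeline
# (the fairseq2 SimulEval run, the offline model load, the streaming wait-k
# loop, and the segment dumps). The wav/manifest paths above are
# cluster-specific, so adjust them before running.
if __name__ == "__main__":
    SimulevalTranscoder()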