""" | |
File: utils.py | |
Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov | |
Description: Utility functions. | |
License: MIT License | |
""" | |
import time
import torch
import os
import subprocess
import bisect
import re
import requests
from torchvision import transforms
from PIL import Image
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from pathlib import Path
from contextlib import suppress, ContextDecorator
from urllib.parse import urlparse
from typing import Callable


class Timer(ContextDecorator):
    """Context manager for measuring code execution time."""

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.execution_time = f"{self.end - self.start:.2f} seconds"

    def __str__(self):
        return self.execution_time
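
# Example usage (illustrative sketch): thanks to ContextDecorator, Timer works
# both as a context manager and as a decorator.
#   with Timer() as timer:
#       heavy_computation()          # hypothetical function
#   print(timer)                     # e.g. "1.23 seconds"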


def load_model(
    model_url: str, folder_path: str, force_reload: bool = False
) -> str | None:
    file_name = Path(urlparse(model_url).path).name
    file_path = Path(folder_path) / file_name

    if file_path.exists() and not force_reload:
        return str(file_path)

    with suppress(Exception), requests.get(model_url, stream=True) as response:
        # Abort (and return None) on HTTP errors instead of saving an error page
        response.raise_for_status()
        file_path.parent.mkdir(parents=True, exist_ok=True)

        with file_path.open("wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)

        return str(file_path)

    return None
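
# Example usage (sketch; the URL below is a placeholder, not a real asset):
#   weights_path = load_model(
#       "https://example.com/checkpoints/model.pt", folder_path="models"
#   )
#   if weights_path is None:
#       raise RuntimeError("Could not download or locate the model weights")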


def readetect_speech(
    file_path: str,
    read_audio: Callable,
    get_speech_timestamps: Callable,
    vad_model: torch.jit.ScriptModule,
    sr: int = 16000,
) -> tuple[torch.Tensor, list[dict]]:
    wav = read_audio(file_path, sampling_rate=sr)
    # Get speech timestamps from the full audio file
    speech_timestamps = get_speech_timestamps(wav, vad_model, sampling_rate=sr)
    return wav, speech_timestamps
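
# Example usage (sketch, assuming the Silero VAD torch.hub entry point, which
# returns the VAD model plus a tuple of helper functions):
#   vad_model, vad_utils = torch.hub.load("snakers4/silero-vad", "silero_vad")
#   get_speech_timestamps, _, read_audio, *_ = vad_utils
#   wav, speech_timestamps = readetect_speech(
#       "audio.wav", read_audio, get_speech_timestamps, vad_model, sr=16000
#   )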


def calculate_mode(series):
    mode = series.mode()
    return mode[0] if not mode.empty else None


def pth_processing(fp):
    class PreprocessInput(torch.nn.Module):
        def __init__(self):
            super(PreprocessInput, self).__init__()

        def forward(self, x):
            x = x.to(torch.float32)
            # Reorder channels and subtract per-channel means
            x = torch.flip(x, dims=(0,))
            x[0, :, :] -= 91.4953
            x[1, :, :] -= 103.8827
            x[2, :, :] -= 131.0912
            return x

    def get_img_torch(img, target_size=(224, 224)):
        transform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()])
        img = img.resize(target_size, Image.Resampling.NEAREST)
        img = transform(img)
        img = torch.unsqueeze(img, 0)
        return img

    return get_img_torch(fp)
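
# Example usage (sketch): preprocess a single RGB face crop (resize, channel
# reorder, per-channel mean subtraction) into a batched tensor.
#   face_img = Image.open("face.jpg").convert("RGB")   # hypothetical path
#   face_tensor = pth_processing(face_img)             # shape: (1, 3, 224, 224)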


def get_idx_frames_in_windows(
    frames: list[int], window: dict, fps: int, sr: int = 16000
) -> list[int]:
    # Window boundaries are given in audio samples; convert them to frame units
    frames_in_windows = [
        idx
        for idx, frame in enumerate(frames)
        if window["start"] * fps / sr <= frame < window["end"] * fps / sr
    ]
    return frames_in_windows
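
# Worked example (illustrative): with fps=25 and sr=16000, a window spanning
# samples 16000..48000 (1.0 s to 3.0 s) covers video frames 25..74, so only
# frames 30 and 40 (indices 3 and 4) fall inside it.
#   get_idx_frames_in_windows([0, 10, 20, 30, 40], {"start": 16000, "end": 48000}, fps=25)
#   -> [3, 4]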


# Maxim code
def slice_audio(
    start_time: float,
    end_time: float,
    win_max_length: float,
    win_shift: float,
    win_min_length: float,
) -> list[dict]:
    """Slices audio into windows

    Args:
        start_time (float): Start time of audio
        end_time (float): End time of audio
        win_max_length (float): Window max length
        win_shift (float): Window shift
        win_min_length (float): Window min length

    Returns:
        list[dict]: List of dicts with timings, e.g.: {'start': 0, 'end': 12}
    """
    if end_time < start_time:
        return []
    elif (end_time - start_time) > win_max_length:
        timings = []

        while start_time < end_time:
            end_time_chunk = start_time + win_max_length

            if end_time_chunk < end_time:
                timings.append({"start": start_time, "end": end_time_chunk})
            elif end_time_chunk == end_time:  # if the tail is exactly `win_max_length` seconds
                timings.append({"start": start_time, "end": end_time_chunk})
                break
            else:  # if the tail is shorter than `win_max_length` seconds
                if (
                    end_time - start_time < win_min_length
                ):  # if the tail is shorter than `win_min_length` seconds
                    break

                timings.append({"start": start_time, "end": end_time})
                break

            start_time += win_shift

        return timings
    else:
        return [{"start": start_time, "end": end_time}]
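
# Worked example (illustrative): slicing a 10-second segment with 4-second
# windows, a 2-second shift, and a 1-second minimum tail.
#   slice_audio(0, 10, win_max_length=4, win_shift=2, win_min_length=1)
#   -> [{'start': 0, 'end': 4}, {'start': 2, 'end': 6},
#       {'start': 4, 'end': 8}, {'start': 6, 'end': 10}]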


def convert_video_to_audio(file_path: str, sr: int = 16000) -> str:
    path_save = os.path.splitext(file_path)[0] + ".wav"

    if not os.path.exists(path_save):
        # Quote paths so files with spaces survive the shell invocation
        ffmpeg_command = (
            f'ffmpeg -y -i "{file_path}" -async 1 -vn -acodec pcm_s16le -ar {sr} "{path_save}"'
        )
        subprocess.call(ffmpeg_command, shell=True)

    return path_save


def find_nearest_frames(target_frames, all_frames):
    nearest_frames = []

    for frame in target_frames:
        pos = bisect.bisect_left(all_frames, frame)

        if pos == 0:
            nearest_frame = all_frames[0]
        elif pos == len(all_frames):
            nearest_frame = all_frames[-1]
        else:
            before = all_frames[pos - 1]
            after = all_frames[pos]
            nearest_frame = before if frame - before <= after - frame else after

        nearest_frames.append(nearest_frame)

    return nearest_frames
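
# Worked example (illustrative): each target frame is matched to the closest
# available frame, with ties resolved towards the earlier frame.
#   find_nearest_frames([3, 7, 12], [0, 5, 10])
#   -> [5, 5, 10]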


def find_intersections(
    x: list[dict], y: list[dict], min_length: float = 0
) -> list[dict]:
    """Finds intersections of two lists of interval dicts, preserving the structure
    of `x` and adding intersection info

    Args:
        x (list[dict]): First list of intervals
        y (list[dict]): Second list of intervals
        min_length (float, optional): Minimum length of an intersection. Defaults to 0.

    Returns:
        list[dict]: Windows with intersections, maintaining the structure of `x` and
            indicating whether speech was found.
    """
    timings = []
    j = 0

    for interval_x in x:
        original_start = int(interval_x["start"])
        original_end = int(interval_x["end"])
        intersections_found = False

        # Skip any intervals in `y` that end before the current interval in `x` starts
        while j < len(y) and y[j]["end"] < original_start:
            j += 1

        # Check all overlapping intervals in `y`
        temp_j = j  # Temporary pointer to scan intersections within `y` for the current `x`

        while temp_j < len(y) and y[temp_j]["start"] <= original_end:
            # Calculate the intersection between the current `x` interval and `y[temp_j]`
            intersection_start = max(original_start, y[temp_j]["start"])
            intersection_end = min(original_end, y[temp_j]["end"])

            if (
                intersection_start < intersection_end
                and (intersection_end - intersection_start) >= min_length
            ):
                timings.append(
                    {
                        "original_start": original_start,
                        "original_end": original_end,
                        "start": intersection_start,
                        "end": intersection_end,
                        "speech": True,
                    }
                )
                intersections_found = True

            temp_j += 1  # Move to the next interval in `y`

        # If no intersections were found, add the interval with `speech` set to False
        if not intersections_found:
            timings.append(
                {
                    "original_start": original_start,
                    "original_end": original_end,
                    "start": None,
                    "end": None,
                    "speech": False,
                }
            )

    return timings
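
# Worked example (illustrative): speech intervals in `y` matched against the
# analysis windows in `x`.
#   find_intersections(
#       [{"start": 0, "end": 5}, {"start": 5, "end": 10}],
#       [{"start": 3, "end": 7}],
#   )
#   -> [{'original_start': 0, 'original_end': 5, 'start': 3, 'end': 5, 'speech': True},
#       {'original_start': 5, 'original_end': 10, 'start': 5, 'end': 7, 'speech': True}]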


# Anastasia code
class ASRModel:
    def __init__(self, checkpoint_path: str, device: torch.device):
        self.processor = WhisperProcessor.from_pretrained(checkpoint_path)
        self.model = WhisperForConditionalGeneration.from_pretrained(
            checkpoint_path
        ).to(device)
        self.device = device
        self.model.config.forced_decoder_ids = None

    def __call__(
        self, sample: torch.Tensor, audio_windows: list[dict], sr: int = 16000
    ) -> tuple:
        texts = []

        for t in range(len(audio_windows)):
            input_features = self.processor(
                sample[audio_windows[t]["start"] : audio_windows[t]["end"]],
                sampling_rate=sr,
                return_tensors="pt",
            ).input_features
            predicted_ids = self.model.generate(input_features.to(self.device))
            transcription = self.processor.batch_decode(
                predicted_ids, skip_special_tokens=False
            )
            curr_text = re.findall(r"> ([^<>]+)", transcription[0])

            if curr_text:
                texts.append(curr_text)
            else:
                texts.append("")

        # Full-file transcription used for drawing
        input_features = self.processor(
            sample, sampling_rate=sr, return_tensors="pt"
        ).input_features
        predicted_ids = self.model.generate(input_features.to(self.device))
        transcription = self.processor.batch_decode(
            predicted_ids, skip_special_tokens=False
        )
        total_text = re.findall(r"> ([^<>]+)", transcription[0])

        return texts, total_text
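
# Example usage (sketch; "openai/whisper-base" is an assumed checkpoint, any
# Whisper checkpoint directory should work):
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   asr = ASRModel("openai/whisper-base", device)
#   texts, total_text = asr(wav, audio_windows, sr=16000)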


def convert_webm_to_mp4(input_file):
    path_save = os.path.splitext(input_file)[0] + ".mp4"

    if not os.path.exists(path_save):
        # Quote paths so files with spaces survive the shell invocation
        ff_video = 'ffmpeg -i "{}" -c:v copy -c:a aac -strict experimental "{}"'.format(
            input_file, path_save
        )
        subprocess.call(ff_video, shell=True)

    return path_save