""" File: utils.py Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov Description: Utility functions. License: MIT License """ import time import torch import os import subprocess import bisect import re import requests from torchvision import transforms from PIL import Image from transformers import WhisperProcessor, WhisperForConditionalGeneration from pathlib import Path from contextlib import suppress from urllib.parse import urlparse from contextlib import ContextDecorator from typing import Callable class Timer(ContextDecorator): """Context manager for measuring code execution time""" def __enter__(self): self.start = time.time() return self def __exit__(self, *args): self.end = time.time() self.execution_time = f"{self.end - self.start:.2f} seconds" def __str__(self): return self.execution_time def load_model( model_url: str, folder_path: str, force_reload: bool = False ) -> str | None: file_name = Path(urlparse(model_url).path).name file_path = Path(folder_path) / file_name if file_path.exists() and not force_reload: return str(file_path) with suppress(Exception), requests.get(model_url, stream=True) as response: file_path.parent.mkdir(parents=True, exist_ok=True) with file_path.open("wb") as file: for chunk in response.iter_content(chunk_size=8192): file.write(chunk) return str(file_path) return None def readetect_speech( file_path: str, read_audio: Callable, get_speech_timestamps: Callable, vad_model: torch.jit.ScriptModule, sr: int = 16000, ) -> list[dict]: wav = read_audio(file_path, sampling_rate=sr) # get speech timestamps from full audio file speech_timestamps = get_speech_timestamps(wav, vad_model, sampling_rate=sr) return wav, speech_timestamps def calculate_mode(series): mode = series.mode() return mode[0] if not mode.empty else None def pth_processing(fp): class PreprocessInput(torch.nn.Module): def init(self): super(PreprocessInput, self).init() def forward(self, x): x = x.to(torch.float32) x = torch.flip(x, dims=(0,)) x[0, :, :] -= 91.4953 x[1, :, :] -= 103.8827 x[2, :, :] -= 131.0912 return x def get_img_torch(img, target_size=(224, 224)): transform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()]) img = img.resize(target_size, Image.Resampling.NEAREST) img = transform(img) img = torch.unsqueeze(img, 0) return img return get_img_torch(fp) def get_idx_frames_in_windows( frames: list[int], window: dict, fps: int, sr: int = 16000 ) -> list[list]: frames_in_windows = [ idx for idx, frame in enumerate(frames) if window["start"] * fps / sr <= frame < window["end"] * fps / sr ] return frames_in_windows # Maxim code def slice_audio( start_time: float, end_time: float, win_max_length: float, win_shift: float, win_min_length: float, ) -> list[dict]: """Slices audio on windows Args: start_time (float): Start time of audio end_time (float): End time of audio win_max_length (float): Window max length win_shift (float): Window shift win_min_length (float): Window min length Returns: list[dict]: List of dict with timings, f.e.: {'start': 0, 'end': 12} """ if end_time < start_time: return [] elif (end_time - start_time) > win_max_length: timings = [] while start_time < end_time: end_time_chunk = start_time + win_max_length if end_time_chunk < end_time: timings.append({"start": start_time, "end": end_time_chunk}) elif end_time_chunk == end_time: # if tail exact `win_max_length` seconds timings.append({"start": start_time, "end": end_time_chunk}) break else: # if tail less then `win_max_length` seconds if ( end_time - 
start_time < win_min_length ): # if tail less then `win_min_length` seconds break timings.append({"start": start_time, "end": end_time}) break start_time += win_shift return timings else: return [{"start": start_time, "end": end_time}] def convert_video_to_audio(file_path: str, sr: int = 16000) -> str: path_save = file_path.split(".")[0] + ".wav" if not os.path.exists(path_save): ffmpeg_command = f"ffmpeg -y -i {file_path} -async 1 -vn -acodec pcm_s16le -ar {sr} {path_save}" subprocess.call(ffmpeg_command, shell=True) return path_save def find_nearest_frames(target_frames, all_frames): nearest_frames = [] for frame in target_frames: pos = bisect.bisect_left(all_frames, frame) if pos == 0: nearest_frame = all_frames[0] elif pos == len(all_frames): nearest_frame = all_frames[-1] else: before = all_frames[pos - 1] after = all_frames[pos] nearest_frame = before if frame - before <= after - frame else after nearest_frames.append(nearest_frame) return nearest_frames def find_intersections( x: list[dict], y: list[dict], min_length: float = 0 ) -> list[dict]: """Find intersections of two lists of dicts with intervals, preserving structure of `x` and adding intersection info Args: x (list[dict]): First list of intervals y (list[dict]): Second list of intervals min_length (float, optional): Minimum length of intersection. Defaults to 0. Returns: list[dict]: Windows with intersections, maintaining structure of `x`, and indicating intersection presence. """ timings = [] j = 0 for interval_x in x: original_start = int(interval_x["start"]) original_end = int(interval_x["end"]) intersections_found = False while j < len(y) and y[j]["end"] < original_start: j += 1 # Skip any intervals in `y` that end before the current interval in `x` starts # Check for all overlapping intervals in `y` temp_j = ( j # Temporary pointer to check intersections within `y` for current `x` ) while temp_j < len(y) and y[temp_j]["start"] <= original_end: # Calculate the intersection between `x[i]` and `y[j]` intersection_start = max(original_start, y[temp_j]["start"]) intersection_end = min(original_end, y[temp_j]["end"]) if ( intersection_start < intersection_end and (intersection_end - intersection_start) >= min_length ): timings.append( { "original_start": original_start, "original_end": original_end, "start": intersection_start, "end": intersection_end, "speech": True, } ) intersections_found = True temp_j += 1 # Move to the next interval in `y` for further intersections # If no intersections were found, add the interval with `intersected` set to False if not intersections_found: timings.append( { "original_start": original_start, "original_end": original_end, "start": None, "end": None, "speech": False, } ) return timings # Anastasia code class ASRModel: def __init__(self, checkpoint_path: str, device: torch.device): self.processor = WhisperProcessor.from_pretrained(checkpoint_path) self.model = WhisperForConditionalGeneration.from_pretrained( checkpoint_path ).to(device) self.device = device self.model.config.forced_decoder_ids = None def __call__( self, sample: torch.Tensor, audio_windows: dict, sr: int = 16000 ) -> tuple: texts = [] for t in range(len(audio_windows)): input_features = self.processor( sample[audio_windows[t]["start"] : audio_windows[t]["end"]], sampling_rate=sr, return_tensors="pt", ).input_features predicted_ids = self.model.generate(input_features.to(self.device)) transcription = self.processor.batch_decode( predicted_ids, skip_special_tokens=False ) curr_text = re.findall(r"> ([^<>]+)", transcription[0]) 
            if curr_text:
                texts.append(curr_text)
            else:
                texts.append("")

        # Full-utterance transcription (used for drawing)
        input_features = self.processor(
            sample, sampling_rate=sr, return_tensors="pt"
        ).input_features
        predicted_ids = self.model.generate(input_features.to(self.device))
        transcription = self.processor.batch_decode(
            predicted_ids, skip_special_tokens=False
        )
        total_text = re.findall(r"> ([^<>]+)", transcription[0])

        return texts, total_text


def convert_webm_to_mp4(input_file):
    path_save = input_file.split(".")[0] + ".mp4"
    if not os.path.exists(path_save):
        ff_video = 'ffmpeg -i "{}" -c:v copy -c:a aac -strict experimental "{}"'.format(
            input_file, path_save
        )
        subprocess.call(ff_video, shell=True)

    return path_save
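

# Minimal usage sketch (illustrative only; not part of the original pipeline).
# It exercises the pure-Python helpers `slice_audio` and `find_intersections`
# together with `Timer`; all numeric values and the `speech_segments`
# intervals below are assumed example data.
if __name__ == "__main__":
    with Timer() as timer:
        # Split a 10-second utterance into 4-second windows with a 2-second
        # shift, dropping any tail shorter than 1 second.
        windows = slice_audio(
            start_time=0.0,
            end_time=10.0,
            win_max_length=4.0,
            win_shift=2.0,
            win_min_length=1.0,
        )

        # Hypothetical VAD output: two speech segments within the utterance.
        speech_segments = [{"start": 1, "end": 5}, {"start": 7, "end": 9}]

        # Mark which windows overlap with the speech segments.
        windows_with_speech = find_intersections(windows, speech_segments)

    print(windows)
    print(windows_with_speech)
    print(f"Elapsed: {timer}")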