whisper-youtube-2-hf_dataset / datapipeline.py
juancopi81's picture
Duplicate from Whispering-GPT/whisper-youtube-2-hf_dataset
7288748
from typing import Dict, List
from pathlib import Path
from sqlite3 import Cursor
from utils import accepts_types, create_videos
from preprocessing.youtubevideopreprocessor import YoutubeVideoPreprocessor
from loading.loaderiterator import LoaderIterator
from transforming.batchtransformer import BatchTransformer
from storing.sqlitebatchvideostorer import SQLiteBatchVideoStorer
from storing.sqlitecontextmanager import SQLiteContextManager
from loading.serialization import JsonSerializer
from transforming.addtitletransform import AddTitleTransform
from transforming.adddescriptiontransform import AddDescriptionTransform
from transforming.whispertransform import WhisperTransform
class DataPipeline:
    """Tie together the loading, transforming and storing components.

    Data moves through three stages: load -> apply transform -> store.
    """

    def __init__(
        self,
        loader_iterator: LoaderIterator,
        batch_transformer: BatchTransformer,
        storer: SQLiteBatchVideoStorer,
        sqlite_context_manager: SQLiteContextManager,
    ) -> None:
        """Keep a reference to each collaborator used by the pipeline.

        Args:
            loader_iterator: Iterable that yields batches of video data once
                its ``load_paths`` attribute has been set.
            batch_transformer: Object whose ``apply`` is called on each batch
                of videos.
            storer: Object whose ``store`` receives the database cursor and
                the transformed videos.
            sqlite_context_manager: Context manager entered for the duration
                of a ``process`` call; the value it yields is passed on as
                the database cursor.
        """
        self.sqlite_context_manager = sqlite_context_manager
        self.storer = storer
        self.batch_transformer = batch_transformer
        self.loader_iterator = loader_iterator

    @accepts_types(list)
    def process(self, load_paths: List[Path]) -> None:
        """Load the given paths batch by batch, transforming and storing each.

        The paths are handed to the loader iterator, and every batch it
        yields is processed while the SQLite context manager is open.

        Args:
            load_paths: Paths assigned to the loader iterator's
                ``load_paths`` attribute.
        """
        loader = self.loader_iterator
        loader.load_paths = load_paths
        with self.sqlite_context_manager as cursor:
            for batch in loader:
                self._process_video_batch(cursor, batch)

    def _process_video_batch(
        self,
        db_cursor: Cursor,
        video_data_batch: List[Dict],
    ) -> None:
        """Turn one batch of raw video data into videos, transform, store.

        Args:
            db_cursor: Cursor forwarded unchanged to the storer.
            video_data_batch: Raw items passed to ``create_videos``.
        """
        processed = self.batch_transformer.apply(create_videos(video_data_batch))
        self.storer.store(db_cursor, processed)
def create_hardcoded_data_pipeline(db_path, whisper_model: str = "base") -> DataPipeline:
    """Build a DataPipeline wired with a fixed set of components.

    The pipeline uses a JSON-backed loader iterator, a batch transformer
    that runs the title, description and Whisper transforms in that order,
    a SQLite batch video storer, and a SQLite context manager for `db_path`.

    TODO: Create DataPipeline so users can pass the args.

    Args:
        db_path: Passed as-is to ``SQLiteContextManager``.
        whisper_model: Forwarded to ``WhisperTransform`` as ``model``;
            defaults to ``"base"``.

    Returns:
        The assembled ``DataPipeline``.
    """
    # NOTE(review): the literal 2 is presumably the batch size -- confirm
    # against LoaderIterator's signature.
    loader = LoaderIterator(JsonSerializer(), 2)
    # TODO: Let the user select these parameters.
    transforms = [
        AddTitleTransform(),
        AddDescriptionTransform(),
        WhisperTransform(model=whisper_model),
    ]
    return DataPipeline(
        loader,
        BatchTransformer(transforms),
        SQLiteBatchVideoStorer(),
        SQLiteContextManager(db_path),
    )