Spaces:
Runtime error
Runtime error
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, pipeline | |
from pydub import AudioSegment | |
import whisper | |
from settings import MODEL_PARSER | |
from pytube import YouTube | |
class BagOfModels:
    """Registry of the models configured in ``settings.MODEL_PARSER``.

    Instance attributes:
        model: mapping of model name -> keyword arguments for ``Model``.
            ``load_model`` unpacks ``model[name]`` into ``Model(name, ...)``,
            so each entry must provide ``task`` and ``url`` and may add only
            ``year`` / ``description``.
        model_names: model names that can be chosen from in streamlit.
        model_settings: settings of the model that can be customized by the
            user.
        model_tasks: task information from the settings (its structure is
            not visible in this file).
        kwargs: any further fields found on ``MODEL_PARSER``.
    """

    # Configuration source. ``vars(args)`` is used as the constructor kwargs,
    # so it must be an object with a ``__dict__`` (presumably an
    # ``argparse.Namespace`` — TODO confirm in settings).
    args = MODEL_PARSER
    barfs = 5  # NOTE(review): not referenced anywhere in this file.

    def __init__(self, model, model_names, model_settings, model_tasks, **kwargs):
        self.model = model
        self.model_names = model_names
        self.model_settings = model_settings
        self.model_tasks = model_tasks
        self.kwargs = kwargs

    @classmethod
    def _from_args(cls):
        """Build a registry instance from the class-level ``args``."""
        return cls(**vars(cls.args))

    # The accessors below must be classmethods: callers use them on the class
    # itself (``BagOfModels.get_model_names()``), which raised TypeError
    # ("missing argument 'cls'") without the decorator.
    @classmethod
    def get_model_settings(cls):
        """Return the configured ``model_settings``."""
        return cls._from_args().model_settings

    @classmethod
    def get_model_names(cls):
        """Return the configured ``model_names``."""
        return cls._from_args().model_names

    @classmethod
    def get_model(cls):
        """Return the configured name -> ``Model`` kwargs mapping."""
        return cls._from_args().model

    @classmethod
    def get_model_tasks(cls):
        """Return the configured ``model_tasks``."""
        return cls._from_args().model_tasks

    @classmethod
    def load_model(cls, model_name, **kwargs):
        """Return a ``Model`` built from the registry entry for ``model_name``.

        Extra keyword arguments are accepted for backward compatibility and
        ignored. As a side effect, the loaded mapping is stored on
        ``cls.model``.

        Raises:
            AssertionError: if ``model_name`` is not in ``model_names``.
        """
        bag_of_models = cls._from_args()
        cls.model = bag_of_models.model
        if model_name not in bag_of_models.model_names:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type is kept for existing callers.
            raise AssertionError(
                f"please pick one of the available models: {bag_of_models.model_names}"
            )
        return Model(model_name, **cls.model[model_name])
class Model:
    """One model entry from the registry.

    Args:
        model_name: registry key such as ``"whisper_base"``; ``predict_stt``
            uses its second ``_``-separated field as the whisper size.
        task: task label of the entry.
        url: Hugging Face URL; the text after ``https://huggingface.co/``
            becomes ``name`` (the repo id given to ``from_pretrained``).
            A URL without that prefix raises ``IndexError``.
        **kwargs: optional ``year`` and ``description``; any other key makes
            ``init_optional_args`` raise ``TypeError``.
    """

    def __init__(self, model_name, task, url, **kwargs):
        self.url = url
        self.model_name = model_name
        hf_prefix = "https://huggingface.co/"
        self.name = url.split(hf_prefix)[1]
        self.task = task
        self.kwargs = kwargs
        self.init_optional_args(**kwargs)

    def init_optional_args(self, year=None, description=None):
        """Record the optional metadata of the entry."""
        self._description = description
        self._year = year

    def predict_stt(self, source, source_type, model_task):
        """Run ``SoundToText.whisper`` on ``source`` and return that instance.

        A whisper model of the size encoded in ``model_name`` is loaded and
        handed to ``SoundToText``.
        """
        whisper_size = self.model_name.split("_")[1]  # tiny - base - medium
        transcriber = SoundToText(
            source,
            source_type,
            model_task,
            model=whisper.load_model(whisper_size),
            tokenizer=None,
        )
        transcriber.whisper()
        return transcriber

    def predict_summary(self):
        """Load the Wav2Vec2 processor and CTC model for ``name``.

        NOTE(review): unfinished — both objects are discarded and the method
        returns ``None``.
        """
        _processor = Wav2Vec2Processor.from_pretrained(self.name)
        _ctc_model = Wav2Vec2ForCTC.from_pretrained(self.name)  # Note: PyTorch Model
class Transcription:
    """Placeholder class; its constructor currently stores nothing."""

    def __init__(self, model, source, source_type) -> None:
        """Accept ``model``, ``source`` and ``source_type`` and ignore them (stub)."""
class SoundToText():
    """Speech-to-text over an audio source.

    Args:
        source: for ``source_type == "YouTube"``, the video URL given to pytube.
        source_type: origin of the audio; only ``"YouTube"`` is downloaded here.
            Any other type requires ``audio_path`` to be set before ``whisper()``.
        model_task: stored on the instance; not read by ``whisper()``.
        model: a loaded whisper model (any object with ``transcribe``); when
            ``None``, ``whisper()`` loads the ``"base"`` model.
        tokenizer: stored on the instance; unused here.
    """

    def __init__(self, source, source_type, model_task, model, tokenizer=None):
        self.source = source
        self.source_type = source_type
        self.model = model
        self.model_task = model_task
        self.tokenizer = tokenizer

    def wav2vec(self, size):
        """Not implemented; returns ``None``."""

    def wav2vec2(self, size):
        """Not implemented; returns ``None``."""

    def whisper(self):
        """Transcribe the audio and store the results on the instance.

        Sets ``audio_path`` (YouTube sources), ``raw_output``, ``text``,
        ``language``, ``segments`` (token ids removed) and ``transcribed``.

        Raises:
            ValueError: if the YouTube video has no stream with itag 140, or
                if ``source_type`` is not ``"YouTube"`` and no ``audio_path``
                was set beforehand.
        """
        if self.source_type == "YouTube":
            # get_by_itag returns None when the video has no such stream.
            stream = YouTube(self.source).streams.get_by_itag(140)
            if stream is None:
                raise ValueError(f"no audio stream with itag 140 for {self.source!r}")
            self.audio_path = stream.download("output/", filename="audio")
        elif getattr(self, "audio_path", None) is None:
            # TODO: local-file sources (.wav/.mp3 via pydub) are not supported yet.
            raise ValueError(f"unsupported source_type: {self.source_type!r}")

        # Use the model the caller supplied (e.g. the size picked in
        # Model.predict_stt); previously "base" was always reloaded here.
        model = self.model if self.model is not None else whisper.load_model("base")
        self.raw_output = model.transcribe(self.audio_path, verbose=True)
        self.text = self.raw_output["text"]
        self.language = self.raw_output["language"]
        self.segments = self.raw_output["segments"]
        # Remove token ids from the output (tolerate segments without them).
        for segment in self.segments:
            segment.pop("tokens", None)
        self.transcribed = True
class TextToSummary():
    """Summarise text with the ``facebook/bart-large-cnn`` summarization pipeline.

    The summary is computed eagerly in ``__init__``; ``get_summary`` only
    returns it.
    """

    def __init__(self, input_text, min_length, max_length):
        self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        self.summary_input = input_text
        generation_options = dict(
            min_length=min_length,
            max_length=max_length,
            do_sample=False,
        )
        self.summary_output = self.summarizer(self.summary_input, **generation_options)

    def get_summary(self):
        """Return the pipeline output computed in ``__init__``."""
        return self.summary_output

    def wav2vec(self):
        """Not implemented; returns ``None``."""
def record(model_name, source=None, source_type="YouTube", model_task=None):
    """Load ``model_name`` from the registry and optionally transcribe ``source``.

    Args:
        model_name: one of ``BagOfModels.get_model_names()``.
        source: audio source for ``Model.predict_stt``; when ``None`` nothing
            is transcribed.
        source_type: forwarded to ``predict_stt`` (default ``"YouTube"``).
        model_task: forwarded to ``predict_stt``.

    Returns:
        The ``SoundToText`` result of ``predict_stt`` when ``source`` is given,
        otherwise the loaded ``Model``.
    """
    # ``Model`` has no ``predict()`` method (the former call always raised
    # AttributeError); ``predict_stt`` is its speech-to-text entry point.
    model = BagOfModels.load_model(model_name)
    if source is None:
        return model
    return model.predict_stt(source, source_type, model_task)
if __name__ == "__main__":
    import sys

    # Usage: python <this file> YOUTUBE_URL [MODEL_NAME]
    # predict_stt needs a source, a source type and a task; it was previously
    # called with no arguments, which raised TypeError.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} YOUTUBE_URL [MODEL_NAME]")
    youtube_url = sys.argv[1]
    chosen_model = sys.argv[2] if len(sys.argv) > 2 else "whisper_base"
    stt_model = BagOfModels.load_model(chosen_model)
    # model_task is only stored by SoundToText, so None is passed here.
    transcript = stt_model.predict_stt(youtube_url, "YouTube", None)
    print(transcript.text)