# Hugging Face Space by Zeph27 — commit 40a56d1 ("add func parameter"), 9.44 kB.
import gradio as gr
import yt_dlp
from dotenv import load_dotenv
import os
import google.generativeai as genai
import re
import torch
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
import time
import spaces
# Populate the environment from a local .env file (python-dotenv) before
# reading configuration values from it.
load_dotenv()

# Fallback Gemini key, used whenever the UI field is left blank.
default_gemini_api_key = os.getenv('gemini_api_key')

# Pipeline device: GPU index 0 when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"
def load_pipeline(model_name):
    """Build an automatic-speech-recognition pipeline for ``model_name``.

    The pipeline is placed on the module-level ``device`` and uses
    30-second chunking (``chunk_length_s``) for long recordings.
    """
    pipeline_kwargs = {
        "task": "automatic-speech-recognition",
        "model": model_name,
        "chunk_length_s": 30,
        "device": device,
    }
    return pipeline(**pipeline_kwargs)
def configure_genai(api_key, model_variant):
    """Register ``api_key`` with google.generativeai and return a model handle.

    Args:
        api_key: Gemini API key.
        model_variant: model name such as "gemini-1.5-pro".

    ``genai.configure`` is a module-level call, so the key applies to the
    whole client module rather than only to the returned model.
    """
    genai.configure(api_key=api_key)
    gemini = genai.GenerativeModel(model_variant)
    return gemini
# An 11-character YouTube video ID that follows either "v=" (watch URLs) or
# a "/" (youtu.be, /embed/, /shorts/ ...). The negative lookahead rejects
# longer runs of ID characters (e.g. the hostname "youtube-nocookie"), so
# only a complete 11-character token is accepted.
_YOUTUBE_ID_RE = re.compile(r'(?:v=|/)([0-9A-Za-z_-]{11})(?![0-9A-Za-z_-])')


def extract_youtube_id(youtube_url):
    """Return the YouTube video ID contained in ``youtube_url``.

    Handles the common URL shapes (``watch?v=ID``, ``youtu.be/ID``,
    ``/embed/ID``, ``/shorts/ID``). The ID is later used as a file name, and
    the character class guarantees it only contains ``[0-9A-Za-z_-]``.

    Returns:
        The 11-character ID, or None when ``youtube_url`` is empty/None or
        contains no recognisable ID.
    """
    if not youtube_url:
        return None
    match = _YOUTUBE_ID_RE.search(youtube_url)
    return match.group(1) if match else None
def download_youtube_audio(youtube_url, output_filename):
    """Download the best available audio of ``youtube_url`` with yt-dlp.

    ``output_filename`` is used verbatim as yt-dlp's ``outtmpl``, and the
    FFmpegExtractAudio post-processor converts the result to 192 kbps MP3.
    No extension is added here, so the final path is decided by yt-dlp; the
    caller in this file assumes ``<output_filename>.mp3``.

    Returns:
        ``output_filename`` unchanged.

    Raises:
        gr.Error: if yt-dlp fails for any reason.
    """
    mp3_postprocessor = {
        'key': 'FFmpegExtractAudio',
        'preferredcodec': 'mp3',
        'preferredquality': '192',
    }
    options = {
        'format': 'bestaudio/best',
        'postprocessors': [mp3_postprocessor],
        'outtmpl': output_filename,
    }
    try:
        with yt_dlp.YoutubeDL(options) as downloader:
            downloader.download([youtube_url])
        print(f"Downloaded audio from YouTube URL: {youtube_url}")
    except Exception as exc:
        reason = str(exc)
        print(f"Error downloading YouTube audio: {reason}")
        raise gr.Error(f"Failed to download YouTube audio: {reason}")
    return output_filename
def summarize_transcription(transcription, model, gemini_prompt):
    """Ask ``model`` to summarize ``transcription`` using ``gemini_prompt``.

    The request text is ``"<gemini_prompt>:\\n\\n<transcription>"``.

    Returns:
        The model's ``.text`` reply. On any failure (including while reading
        ``.text``) the error is printed and returned as a string prefixed with
        "Error summarizing transcription:" instead of being raised. This is
        deliberate best-effort so the transcription is still delivered.
    """
    try:
        reply = model.generate_content(f"{gemini_prompt}:\n\n{transcription}")
        return reply.text
    except Exception as exc:
        failure = f"Error summarizing transcription: {str(exc)}"
        print(failure)
        return failure
@spaces.GPU(duration=120)
def process_audio(audio_file, language, whisper_model):
    """Transcribe the audio file at path ``audio_file`` with a Whisper pipeline.

    Args:
        audio_file: path to an audio file readable by ffmpeg.
        language: optional language code; falsy means let the model decide.
        whisper_model: model id, honoured only when ``device == 0`` (GPU);
            on CPU "openai/whisper-tiny" is always used instead.

    Returns:
        The transcribed text (timestamps are requested but discarded).

    The decorator presumably requests a ZeroGPU slot for up to 120 s per
    call. TODO confirm against the ``spaces`` package docs.
    """
    print("Starting transcription...")
    model_id = whisper_model if device == 0 else "openai/whisper-tiny"
    pipe = load_pipeline(model_id)
    sample_rate = pipe.feature_extractor.sampling_rate

    with open(audio_file, "rb") as handle:
        raw_audio = handle.read()
    # Decode the raw file bytes with ffmpeg at the model's sampling rate.
    audio = {"array": ffmpeg_read(raw_audio, sample_rate), "sampling_rate": sample_rate}

    generation_options = {"task": "transcribe"}
    if language:
        print(f"Using language: {language}")
        generation_options["language"] = language
    else:
        print("No language defined, using default language")

    result = pipe(audio, batch_size=8, generate_kwargs=generation_options, return_timestamps=True)
    return result["text"]
def _resolve_audio_path(youtube_url, audio_file, progress):
    """Return ``(path, downloaded)`` for the audio that should be transcribed.

    With a ``youtube_url`` the audio is downloaded and ``downloaded`` is True,
    meaning the caller owns the file and must delete it. Otherwise the
    uploaded file's path is returned with ``downloaded`` False, and that file
    is left alone.

    Raises:
        gr.Error: if neither a URL nor an uploaded file was supplied, or the
            download fails.
    """
    if youtube_url:
        progress(0.1, desc="Extracting YouTube ID")
        youtube_id = extract_youtube_id(youtube_url)
        output_filename = youtube_id if youtube_id else "unknown"
        progress(0.2, desc="Downloading YouTube audio")
        # yt-dlp is given an extension-less template; the MP3 post-processor
        # output is expected next to it as "<name>.mp3".
        audio_path = f"{download_youtube_audio(youtube_url, output_filename)}.mp3"
        print(f"Audio file downloaded: {audio_path}")
        return audio_path, True

    if not audio_file:
        raise gr.Error("Please provide a YouTube URL or upload an audio file.")
    progress(0.2, desc="Reading audio file")
    # The component value is read via ``.name``. Fall back to the value itself
    # in case a plain path string is handed over (Gradio-version dependent;
    # confirm).
    audio_path = str(getattr(audio_file, "name", audio_file))
    print(f"Audio file read: {audio_path}")
    return audio_path, False


def _delete_downloaded_audio(path):
    """Best-effort removal of an audio file this app downloaded; no-op for None."""
    if not path or not os.path.exists(path):
        return
    try:
        os.remove(path)
    except OSError as err:
        print(f"Could not delete audio file {path}: {err}")
    else:
        print(f"Deleted audio file: {path}")


def transcribe(youtube_url, audio_file, whisper_model, gemini_api_key, gemini_prompt, gemini_model_variant, language, progress=gr.Progress()):
    """Transcribe a YouTube video or uploaded audio file and summarize it with Gemini.

    Args:
        youtube_url: video URL; when truthy it takes precedence over ``audio_file``.
        audio_file: uploaded file from the Gradio File component.
        whisper_model: Whisper model id passed to ``process_audio``.
        gemini_api_key: user key; falls back to ``default_gemini_api_key`` when blank.
        gemini_prompt: instruction prepended to the transcription for the summary.
        gemini_model_variant: Gemini model name.
        language: optional transcription language code.
        progress: Gradio progress tracker.

    Returns:
        ``(transcription, summary, transcription_file, summary_file, seconds)``.
        The two files are written to the current directory, and ``seconds``
        is the elapsed wall time rounded to 2 decimals.

    Raises:
        gr.Error: on missing input or any failure during processing.
    """
    start_time = time.time()
    # Path of a file we downloaded and must delete; None for user uploads.
    downloaded_audio = None
    try:
        progress(0, desc="Initializing")
        model = configure_genai(gemini_api_key or default_gemini_api_key, gemini_model_variant)

        audio_path, is_download = _resolve_audio_path(youtube_url, audio_file, progress)
        if is_download:
            downloaded_audio = audio_path

        progress(0.4, desc="Starting transcription")
        transcription = process_audio(audio_path, language, whisper_model)

        progress(0.6, desc="Cleaning up")
        # The previous check looked for "<path>.mp3" even though the path
        # already ended in ".mp3", so downloads were never removed.
        _delete_downloaded_audio(downloaded_audio)
        downloaded_audio = None

        progress(0.7, desc="Summarizing transcription")
        summary = summarize_transcription(transcription, model, gemini_prompt)

        progress(0.8, desc="Preparing output")
        transcription_message = f"{transcription}" if transcription else ""
        summary_message = f"{summary}" if summary else ""

        progress(0.9, desc="Saving output to file")
        print("Saving transcription and summary to file...")
        transcription_file = "transcription_output.txt"
        summary_file = "summary_output.txt"
        with open(transcription_file, "w", encoding="utf-8") as f:
            f.write(transcription_message)
        with open(summary_file, "w", encoding="utf-8") as f:
            f.write(summary_message)

        progress(1, desc="Complete")
        print("Transcription and summarization complete.")
        total_time = round(time.time() - start_time, 2)
        return transcription_message, summary_message, transcription_file, summary_file, total_time
    except gr.Error:
        # Already user-facing; propagate unchanged.
        raise
    except Exception as e:
        print(f"Error during transcription or summarization: {str(e)}")
        raise gr.Error(f"Transcription or summarization failed: {str(e)}") from e
    finally:
        # Still remove the download if we failed before the clean-up step.
        _delete_downloaded_audio(downloaded_audio)
def toggle_input(choice):
    """Show the input widget matching the selected input type.

    Returns Gradio updates for ``(youtube_url, audio_file)``. The widget being
    hidden is also cleared (``value=None``) so a stale value cannot be
    submitted.
    """
    use_url = choice == "YouTube URL"
    if not use_url:
        return gr.update(visible=False, value=None), gr.update(visible=True)
    return gr.update(visible=True), gr.update(visible=False, value=None)
def toggle_language(choice):
    """Show or hide the language dropdown when the "Define Language" box changes.

    Checked: show the dropdown preset to "id". Unchecked: hide it and reset
    the value to "" so ``process_audio`` treats the language as undefined.
    """
    if choice:
        return gr.update(visible=True, value="id")
    return gr.update(visible=False, value="")
# Gradio UI: an input accordion, an output accordion, and the event wiring.
# Event listeners must be attached inside the Blocks context.
with gr.Blocks(theme='NoCrypt/miku') as demo:
    gr.Label('Youtube Summarizer WebUI created with ❤️ by Ryusui', show_label=False)
    with gr.Accordion("Input"):
        with gr.Column():
            input_type = gr.Radio(["YouTube URL", "Audio File"], label="Input Type", value="Audio File", info="Please consider using the audio file if you face any issues with the YouTube URL. Currently youtube is banning HuggingFace IP Addresses.")
            # Only one of these two is visible at a time (see toggle_input).
            with gr.Row():
                youtube_url = gr.Textbox(label="YouTube URL", visible=False, info="Input the full URL of the YouTube video you want to transcribe and summarize. Example: https://www.youtube.com/watch?v=VIDEO_ID")
                audio_file = gr.File(label="Upload Audio File", visible=True, file_types=['.wav', '.flac', '.mp3'])
            whisper_model = gr.Dropdown(["openai/whisper-tiny", "openai/whisper-base", "openai/whisper-small", "openai/whisper-medium", "openai/whisper-large-v3", "distil-whisper/distil-large-v3"], label="Whisper Model", value="distil-whisper/distil-large-v3", info="Tiny is the fastest model, but it's not the best quality. large-v3 is the best quality, but it's the slowest model.")
            gemini_model_variant = gr.Dropdown(["gemini-1.5-flash", "gemini-1.5-pro"], label="Gemini Model Variant", value="gemini-1.5-pro", info="Gemini-1.5-flash is the fastest model, but it's not the best quality. Gemini-1.5-pro is the best quality, but it's slower")
            define_language = gr.Checkbox(label="Define Language", value=False, info="If you want to define the language, check this box")
            language = gr.Dropdown(["id","en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"], label="Language", value=None, info="Select the language for transcription", visible=False)
            # Masked input: the key is a secret and should not be shown on screen.
            gemini_api_key = gr.Textbox(label="Gemini API Key (Optional)", type="password", placeholder="Enter your Gemini API key or leave blank to use default", info="If you facing error on transcription, please try to use your own API key")
            # Default prompt (Indonesian): "Make a summary of this transcript".
            gemini_prompt = gr.Textbox(label="Gemini Prompt", value="Buatkan resume dari transkrip ini")
            transcribe_button = gr.Button("Transcribe and Summarize")
    with gr.Accordion("Output"):
        with gr.Column():
            transcription_output = gr.Textbox(label="Transcription Output")
            summary_output = gr.Textbox(label="Summary Output")
            transcription_file = gr.File(label="Download Transcription")
            summary_file = gr.File(label="Download Summary")
            processing_time = gr.Textbox(label="Total Processing Time (seconds)")
    input_type.change(fn=toggle_input, inputs=input_type, outputs=[youtube_url, audio_file])
    define_language.change(fn=toggle_language, inputs=define_language, outputs=[language])
    # Input order must match transcribe()'s positional parameters.
    transcribe_button.click(
        fn=transcribe,
        inputs=[
            youtube_url,
            audio_file,
            whisper_model,
            gemini_api_key,
            gemini_prompt,
            gemini_model_variant,
            language,
        ],
        outputs=[transcription_output, summary_output, transcription_file, summary_file, processing_time]
    )

# Launch only when run as a script, so importing the module has no side effects.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    demo.launch()