Spaces:
Running
Running
import torch | |
import gradio as gr | |
import yt_dlp as youtube_dl | |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline | |
from transformers.pipelines.audio_utils import ffmpeg_read | |
import tempfile | |
import os | |
import time | |
import requests | |
from playwright.sync_api import sync_playwright | |
from languages import get_language_names | |
from subtitle import text_output, subtitle_output | |
import subprocess | |
# Optional Hugging Face Spaces integration: the `spaces` package is only
# importable on Spaces, so record whether it is available.
try:
    import spaces
    USING_SPACES = True
except ImportError:
    USING_SPACES = False

import sys

# Best-effort install of flash-attn at startup with its CUDA build skipped.
# The child process must inherit the current environment (PATH, HOME, ...):
# passing only the flag as `env` would wipe everything else. Running pip via
# the current interpreter installs into the environment this process imports
# from. Failures are ignored (check=False), as before.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Download Playwright's browser binary. Only Chromium is launched
# (see _return_yt_info), so only Chromium is installed.
subprocess.run([sys.executable, "-m", "playwright", "install", "chromium"], check=False)

# Longest YouTube video (seconds) that _download_yt_audio accepts.
YT_LENGTH_LIMIT_S = 360
# Intended duration (seconds) for spaces.GPU; NOTE(review): not referenced in this file.
SPACES_GPU_DURATION = 90
# Device for transformers pipelines: CUDA device index 0 if available, else CPU.
device = 0 if torch.cuda.is_available() else "cpu"
def gpu_decorator(duration=60):
    """Build a decorator that requests a Spaces GPU for ``duration`` seconds.

    When not running on Hugging Face Spaces (``USING_SPACES`` is False) the
    returned decorator hands the function back unchanged.
    """
    def _wrap(fn):
        if not USING_SPACES:
            return fn
        return spaces.GPU(duration=duration)(fn)

    return _wrap
def device_info():
    """Print disk, block-device, memory, CPU and NVIDIA GPU diagnostics.

    Each command runs independently, so one failing or missing tool (for
    example ``nvidia-smi`` on a CPU-only host) does not stop the others.
    Failures are printed, never raised.
    """
    commands = (
        ["df", "-h"],
        ["lsblk"],
        ["free", "-h"],
        ["lscpu"],
        ["nvidia-smi"],
    )
    for command in commands:
        try:
            subprocess.run(command, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # CalledProcessError: non-zero exit status.
            # OSError (e.g. FileNotFoundError): the tool is not installed.
            print(f"Command failed: {e}")
from gpustat import GPUStatCollection | |
def update_gpu_status():
    """Return a Markdown status line for the first GPU reported by gpustat.

    Returns "No Nvidia Device" when CUDA is unavailable. When gpustat reports
    no GPUs, or querying it fails, the torch-based
    ``torch_update_gpu_status()`` is used instead, so a string is always
    returned.
    """
    if not torch.cuda.is_available():
        return "No Nvidia Device"
    try:
        gpu_stats = GPUStatCollection.new_query()
        # Only the first GPU is reported.
        gpu = next(iter(gpu_stats), None)
        if gpu is not None:
            memory_used = gpu.memory_used
            memory_total = gpu.memory_total
            memory_utilization = (memory_used / memory_total) * 100
            return (
                f"> **GPU** {gpu.index}: {gpu.name}, Utilization: {gpu.utilization}%, "
                f"**Memory Used**: {memory_used}MB, **Memory Total**: {memory_total}MB, "
                f"**Memory Utilization**: {memory_utilization:.2f}%"
            )
    except Exception as e:
        print(f"Error getting GPU stats: {e}")
    return torch_update_gpu_status()
def torch_update_gpu_status():
    """Return a Markdown status line for CUDA device 0 using torch only.

    Memory figures come from ``torch.cuda.mem_get_info(0)`` (free, total bytes)
    and are reported in MB with two decimals. Returns "No GPU available" when
    CUDA is not available.
    """
    if not torch.cuda.is_available():
        return "No GPU available"
    mib = 1024 * 1024
    gpu_info = torch.cuda.get_device_name(0)
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    free_memory = free_bytes / mib
    total_memory = total_bytes / mib
    used_memory = (total_bytes - free_bytes) / mib
    return (
        f"> **GPU**: {gpu_info} **Free Memory**: {free_memory:.2f} MB "
        f"**Total Memory**: {total_memory:.2f} MB **Used Memory**: {used_memory:.2f} MB"
    )
def update_cpu_status():
    """Return a Markdown line with CPU usage and the local time (HH:MM:SS).

    Usage comes from ``psutil.cpu_percent()``. psutil is imported here because
    it is not imported at module level. If it is not installed, "unavailable"
    is shown instead of crashing the status panel.
    """
    import datetime

    try:
        import psutil
    except ImportError:
        psutil = None

    time_str = datetime.datetime.now().strftime("%H:%M:%S")
    cpu_usage = f"{psutil.cpu_percent()}%" if psutil is not None else "unavailable"
    # Bold markers are balanced around the label.
    return f"> **CPU Usage**: {cpu_usage} {time_str}"
def update_status():
    """Return the GPU status line and the CPU status line joined by a newline."""
    return "\n".join((update_gpu_status(), update_cpu_status()))
def refresh_status():
    """Callback for the "Refresh System Status" button: current system status."""
    status = update_status()
    return status
def transcribe(inputs, model, language, batch_size, chunk_length_s, stride_length_s, task, timestamp_mode, progress=gr.Progress(track_tqdm=True)):
    """Transcribe or translate an uploaded audio/video file with a Whisper model.

    Args:
        inputs: Value from the Gradio Audio/Video input (a file path), or None.
        model: Hugging Face repo id of a Whisper-style checkpoint.
        language: Language name, or "Automatic Detection" to pass no language.
        batch_size: Batch size for the ASR pipeline call.
        chunk_length_s: Chunk length (seconds) for long-form audio.
        stride_length_s: Stride (seconds) between chunks.
        task: "transcribe" or "translate"; not passed to ".en" checkpoints.
        timestamp_mode: Forwarded as ``return_timestamps`` (True, False or "word").
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        The result of ``text_output`` (no timestamps) or ``subtitle_output``.

    Raises:
        gr.Error: user-facing error for missing input or any failure.
    """
    if inputs is None:
        raise gr.Error(
            "No audio file submitted! Please upload or record an audio file before submitting your request.",
            duration=20,
        )
    try:
        # Half precision is only dependable on CUDA; use fp32 on CPU.
        torch_dtype = torch.float32 if device == "cpu" else torch.float16
        model_gen = AutoModelForSpeechSeq2Seq.from_pretrained(
            model, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        model_gen.to(device)
        processor = AutoProcessor.from_pretrained(model)
        tokenizer = WhisperTokenizer.from_pretrained(model)
        pipe = pipeline(
            task="automatic-speech-recognition",
            model=model_gen,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            tokenizer=tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            # NOTE(review): model_kwargs are normally applied only when `model`
            # is a repo id string; with an already-loaded model this is likely a no-op.
            model_kwargs={"attn_implementation": "flash_attention_2"},
            device=device,
        )

        # English-only checkpoints (".en") are given neither a language nor a task.
        generate_kwargs = {}
        if not model.endswith(".en"):
            if language != "Automatic Detection":
                generate_kwargs["language"] = language
            generate_kwargs["task"] = task

        output = pipe(inputs, batch_size=batch_size, generate_kwargs=generate_kwargs, return_timestamps=timestamp_mode)
        print(output)
        print({"inputs": inputs, "model": model, "language": language, "batch_size": batch_size, "chunk_length_s": chunk_length_s, "stride_length_s": stride_length_s, "task": task, "timestamp_mode": timestamp_mode})
        if not timestamp_mode:
            return text_output(inputs, output['text'])
        return subtitle_output(inputs, output['chunks'])
    except gr.Error:
        # Already a user-facing error; propagate unchanged.
        raise
    except Exception as e:
        raise gr.Error(str(e), duration=20) from e
def _download_yt_audio(yt_url, filename):
    """Download the audio of a YouTube video to ``filename`` using yt-dlp.

    The video's metadata is fetched first to enforce ``YT_LENGTH_LIMIT_S``.

    Raises:
        gr.Error: if metadata cannot be fetched, the duration is unknown, the
            video exceeds the length limit, or the download fails.
    """
    info_loader = youtube_dl.YoutubeDL()
    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        raise gr.Error(str(err)) from err

    file_length_s = _yt_duration_seconds(info)
    if file_length_s is None:
        raise gr.Error("Video duration is unavailable.")
    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.", duration=20)

    ydl_opts = {
        "outtmpl": filename,
        "format": "bestaudio[ext=m4a]/best",
    }
    try:
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            ydl.download([yt_url])
    except youtube_dl.utils.ExtractorError as err:
        # Reuse the metadata fetched above rather than querying the site again.
        available_formats = info.get("formats")
        raise gr.Error(f"Requested format not available. Available formats: {available_formats}", duration=20) from err
    except youtube_dl.utils.DownloadError as err:
        # ydl.download reports most failures as DownloadError.
        raise gr.Error(str(err), duration=20) from err


def _yt_duration_seconds(info):
    """Return the video length in whole seconds from yt-dlp metadata, or None.

    Prefers the numeric ``duration`` field. Otherwise parses
    ``duration_string`` ("SS", "M:SS" or "H:MM:SS").
    """
    duration = info.get("duration")
    if duration is not None:
        return int(duration)
    file_length = info.get("duration_string")
    if not file_length:
        return None
    seconds = 0
    for part in file_length.split(":"):
        seconds = seconds * 60 + int(part)
    return seconds
def _return_yt_video_id(yt_url): | |
if "https://www.youtube.com/watch?v=" in yt_url: | |
video_id = yt_url.split("?v=")[-1] | |
elif "https://youtu.be/" in yt_url: | |
video_id = yt_url.split("be/")[1] | |
return video_id | |
def _return_yt_html_embed(yt_url):
    """Return centered HTML embedding the YouTube player for ``yt_url``.

    Raises:
        ValueError: if no video id can be extracted from the URL.
    """
    from urllib.parse import quote

    video_id = _return_yt_video_id(yt_url)
    if not video_id:
        raise ValueError("Invalid YouTube URL: Unable to extract video ID.")
    # The id is derived from user input: percent-encode it so it cannot break
    # out of the src attribute. Plain alphanumeric/-/_ ids pass through unchanged.
    safe_id = quote(video_id, safe="")
    HTML_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{safe_id}"> </iframe>'
        " </center>"
    )
    return HTML_str
def _return_yt_thumbnail(yt_url):
    """Download the thumbnail of ``yt_url`` into a temporary ``.jpg`` file.

    Tries ``maxresdefault`` first and falls back to ``hqdefault``, since the
    max-resolution image may not exist for every video.

    Returns:
        Path of the temporary file (not deleted automatically), or None if the
        thumbnail could not be fetched or written. Errors are printed.

    Raises:
        ValueError: if no video id can be extracted from the URL.
    """
    video_id = _return_yt_video_id(yt_url)
    if not video_id:
        raise ValueError("Invalid YouTube URL: Unable to extract video ID.")
    try:
        response = None
        for quality in ("maxresdefault", "hqdefault"):
            thumbnail_url = f"https://img.youtube.com/vi/{video_id}/{quality}.jpg"
            # Bounded wait so a stalled connection cannot hang the request.
            response = requests.get(thumbnail_url, timeout=10)
            if response.status_code == 200:
                break
        else:
            raise Exception(f"Failed to retrieve thumbnail. Status code: {response.status_code}")
        # Create the temp file only once there is content, so failed fetches
        # do not leave empty files behind.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
            temp_file.write(response.content)
        return temp_file.name
    except Exception as e:
        print(f"Error occurred: {e}")
        return None
def _return_yt_info(yt_url):
    """Scrape title, description and keywords of a YouTube page with Playwright.

    Returns:
        Three ``gr.Textbox`` updates (title, description, keywords). They are
        visible when the page loads. All three are hidden when the URL has no
        recognizable YouTube video id or loading fails. A missing description
        or keywords meta tag yields an empty value instead of hiding everything.
    """
    hidden = (gr.Textbox(visible=False), gr.Textbox(visible=False), gr.Textbox(visible=False))
    # Only drive the headless browser to URLs recognized as YouTube videos.
    if not _return_yt_video_id(yt_url):
        return hidden

    def _meta_content(page, name):
        # Content of <meta name=...>, or "" when the tag is absent.
        element = page.query_selector(f"meta[name='{name}']")
        return element.get_attribute("content") if element is not None else ""

    try:
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            try:
                page = browser.new_page()
                page.goto(yt_url)
                page.wait_for_load_state("networkidle")
                title = page.title()
                description = _meta_content(page, "description")
                keywords = _meta_content(page, "keywords")
            finally:
                # Close the browser even if navigation or scraping fails.
                browser.close()
        gr_title = gr.Textbox(label="YouTube Title", visible=True, value=title)
        gr_description = gr.Textbox(label="YouTube Description", visible=True, value=description)
        gr_keywords = gr.Textbox(label="YouTube Keywords", visible=True, value=keywords)
        return gr_title, gr_description, gr_keywords
    except Exception as e:
        print(e)
        return hidden
def return_youtube(yt_url):
    """Build the YouTube preview outputs for ``yt_url``.

    Returns:
        (embed HTML, thumbnail image, title, description, keywords) as Gradio
        component updates.
    """
    embed = gr.HTML(label="Youtube Video", visible=True, value=_return_yt_html_embed(yt_url))
    thumb = gr.Image(label="Youtube Thumbnail", visible=True, value=_return_yt_thumbnail(yt_url))
    title, description, keywords = _return_yt_info(yt_url)
    return embed, thumb, title, description, keywords
def yt_transcribe(yt_url, model, language, batch_size, chunk_length_s, stride_length_s, task, timestamp_mode):
    """Download a YouTube video's audio and transcribe it with a Whisper model.

    The YouTube preview outputs (title, embed, thumbnail, description,
    keywords) are built first and returned even when transcription fails. In
    that case the error is shown as a Gradio warning and the transcript/file
    outputs are cleared.

    Returns:
        (subtitle, files, title, html, thumbnail, description, keywords)
    """
    gr_html, gr_thumbnail, gr_title, gr_description, gr_keywords = return_youtube(yt_url)
    try:
        # Download first: the length limit is enforced before the model is loaded.
        with tempfile.TemporaryDirectory() as tmpdirname:
            filepath = os.path.join(tmpdirname, "video.mp4")
            _download_yt_audio(yt_url, filepath)
            with open(filepath, "rb") as f:
                audio_bytes = f.read()

        # Half precision is only dependable on CUDA; use fp32 on CPU.
        torch_dtype = torch.float32 if device == "cpu" else torch.float16
        model_gen = AutoModelForSpeechSeq2Seq.from_pretrained(
            model, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        model_gen.to(device)
        processor = AutoProcessor.from_pretrained(model)
        tokenizer = WhisperTokenizer.from_pretrained(model)
        pipe = pipeline(
            task="automatic-speech-recognition",
            model=model_gen,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            tokenizer=tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            # NOTE(review): likely a no-op with an already-loaded model; confirm.
            model_kwargs={"attn_implementation": "flash_attention_2"},
            device=device,
        )

        # Decode only after `pipe` exists: its feature extractor defines the
        # target sampling rate.
        sampling_rate = pipe.feature_extractor.sampling_rate
        inputs = {"array": ffmpeg_read(audio_bytes, sampling_rate), "sampling_rate": sampling_rate}

        # English-only checkpoints (".en") are given neither a language nor a task.
        generate_kwargs = {}
        if not model.endswith(".en"):
            if language != "Automatic Detection":
                generate_kwargs["language"] = language
            generate_kwargs["task"] = task

        output = pipe(inputs, batch_size=batch_size, generate_kwargs=generate_kwargs, return_timestamps=timestamp_mode)
        print(output)
        print({"inputs": yt_url, "model": model, "language": language, "batch_size": batch_size, "chunk_length_s": chunk_length_s, "stride_length_s": stride_length_s, "task": task, "timestamp_mode": timestamp_mode})
        # NOTE(review): `inputs` here is the audio dict, not a file path as in
        # transcribe(); confirm text_output/subtitle_output handle it.
        if not timestamp_mode:
            subtitle, files = text_output(inputs, output['text'])
        else:
            subtitle, files = subtitle_output(inputs, output['chunks'])
        return subtitle, files, gr_title, gr_html, gr_thumbnail, gr_description, gr_keywords
    except Exception as e:
        gr.Warning(str(e), duration=20)
        # Clear rather than hide the transcript outputs: the success path
        # returns plain values, which would never make hidden components
        # visible again.
        return None, None, gr_title, gr_html, gr_thumbnail, gr_description, gr_keywords
demo = gr.Blocks()

# Checkpoints offered in the model dropdown of every tab; any other Hub repo
# id may also be typed in (allow_custom_value=True).
_WHISPER_MODELS = (
    "openai/whisper-tiny",
    "openai/whisper-base",
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-large",
    "openai/whisper-large-v1",
    "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
    "openai/whisper-large-v3", "openai/whisper-large-v3-turbo", "distil-whisper/distil-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
)


def _transcription_settings():
    """Return fresh input components shared by every transcription tab.

    The order matches the parameters after the media/URL argument of
    transcribe() and yt_transcribe(): model, language, batch size, chunk
    length, stride length, task, timestamp mode.
    """
    return [
        gr.Dropdown(
            choices=list(_WHISPER_MODELS),
            value="openai/whisper-large-v3-turbo",
            label="Model Name",
            allow_custom_value=True,
        ),
        gr.Dropdown(
            choices=["Automatic Detection"] + sorted(get_language_names()),
            value="Automatic Detection",
            label="Language",
            interactive=True,
        ),
        gr.Slider(label="Batch Size", minimum=1, maximum=32, value=16, step=1),
        gr.Slider(label="Chunk Length (s)", minimum=1, maximum=60, value=17.5, step=0.1),
        gr.Slider(label="Stride Length (s)", minimum=1, maximum=30, value=1, step=0.1),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
        gr.Dropdown(
            choices=[True, False, "word"],
            value=True,
            label="Timestamp Mode",
        ),
    ]


def _transcript_outputs():
    """Return fresh transcript-text and download-file output components."""
    return [gr.Textbox(label="Output"), gr.File(label="Download Files")]


file_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources=['upload', 'microphone'], type="filepath", label="Audio file"),
        *_transcription_settings(),
    ],
    outputs=_transcript_outputs(),
    title="Whisper: Transcribe Audio",
    flagging_mode="auto",
)

video_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Video(sources=["upload", "webcam"], label="Video file", show_label=False, show_download_button=False, show_share_button=False, streaming=True),
        *_transcription_settings(),
    ],
    outputs=_transcript_outputs(),
    title="Whisper: Transcribe Video",
    flagging_mode="auto",
)

# NOTE: rebinding `yt_transcribe` to the Interface shadows the function of the
# same name; `fn=` captures the function before the rebinding.
yt_transcribe = gr.Interface(
    fn=yt_transcribe,
    inputs=[
        gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
        *_transcription_settings(),
    ],
    outputs=[
        *_transcript_outputs(),
        gr.Textbox(label="Youtube Title"),
        gr.HTML(label="Youtube Video"),
        gr.Image(label="Youtube Thumbnail"),
        gr.Textbox(label="Youtube Description"),
        gr.Textbox(label="Youtube Keywords"),
    ],
    title="Whisper: Transcribe YouTube",
    flagging_mode="auto",
)
# Tab label -> interface, in display order.
_TABS = {
    "Audio": file_transcribe,
    "Video": video_transcribe,
    "YouTube": yt_transcribe,
}

with demo:
    gr.TabbedInterface(
        interface_list=list(_TABS.values()),
        tab_names=list(_TABS),
    )
    # System status panel; its text is produced by refresh_status on click.
    with gr.Group():
        sys_status_output = gr.Markdown(label="System Status", interactive=False)
        refresh_button = gr.Button("Refresh System Status")
        refresh_button.click(refresh_status, None, sys_status_output)

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)