Spaces:
Runtime error
Runtime error
import base64
import io
import os
import tempfile
import time

import gradio as gr
import moviepy.editor as mp
import spaces
import torch
from decord import VideoReader, cpu
from PIL import Image
from transformers import AutoModel, AutoTokenizer, pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
# Vision-language model that answers the prompt about the sampled frames.
# trust_remote_code=True is passed because the checkpoint is loaded through
# the generic Auto* classes with code supplied by the model repository.
model_path = 'openbmb/MiniCPM-V-2_6'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# NOTE(review): the model is placed on CUDA unconditionally, while the ASR
# pipeline in this file falls back to CPU when CUDA is unavailable -- confirm
# that a GPU is always present where this runs.
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device='cuda')
# Inference only: switch off training-mode behaviour.
model.eval()
# Speech-to-text pipeline used to transcribe the video's audio track.
whisper_model = "openai/whisper-large-v3"
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=whisper_model,
    # Audio is processed in 30-second chunks.
    chunk_length_s=30,
    # Use the GPU when one is visible to torch, otherwise run on CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Upper bound on how many sampled frames encode_video hands to the model.
MAX_NUM_FRAMES = 64
def encode_image(image):
    """Return *image* as a PIL image whose longest side is at most 448*16 px.

    Args:
        image: A ``PIL.Image.Image``, or anything ``Image.open`` accepts
            (for example a file path). Non-PIL inputs are opened and
            converted to RGB; PIL inputs keep their existing mode.

    Returns:
        The image itself when it already fits, otherwise a bicubic-resized
        copy with the aspect ratio preserved (dimensions truncated to int).
    """
    if not isinstance(image, Image.Image):
        image = Image.open(image).convert("RGB")

    longest_allowed = 448 * 16
    if max(image.size) <= longest_allowed:
        # Already within bounds: hand the image back untouched.
        return image

    width, height = image.size
    if width > height:
        # Landscape: pin the width, scale the height proportionally.
        target_size = (longest_allowed, int(height * longest_allowed / width))
    else:
        # Portrait or square: pin the height, scale the width proportionally.
        target_size = (int(width * longest_allowed / height), longest_allowed)
    return image.resize(target_size, resample=Image.BICUBIC)
def encode_video(video_path):
    """Sample about one frame per second of a video as PIL images.

    Args:
        video_path: Path of the video file to decode.

    Returns:
        A list of at most ``MAX_NUM_FRAMES`` PIL images, each passed through
        ``encode_image``. The list is empty when the video has no frames.

    Note:
        When more than ``MAX_NUM_FRAMES`` samples are available only the
        first ``MAX_NUM_FRAMES`` are kept, so later parts of a long video
        are not represented.
    """
    vr = VideoReader(video_path, ctx=cpu(0))

    # Step by one second's worth of frames. Clamp to 1: a reported average
    # frame rate below 0.5 rounds to 0, and range() rejects a zero step.
    step = max(1, round(vr.get_avg_fps()))
    frame_idx = list(range(0, len(vr), step)[:MAX_NUM_FRAMES])
    if not frame_idx:
        return []

    frames = vr.get_batch(frame_idx).asnumpy()
    return [encode_image(Image.fromarray(frame.astype('uint8'))) for frame in frames]
def extract_audio(video_path):
    """Write the audio track of a video to a temporary WAV file.

    Args:
        video_path: Path of the video file to read.

    Returns:
        Path of the WAV file that was written. The caller owns the file and
        is responsible for deleting it.

    Raises:
        ValueError: If the video has no audio track.
    """
    clip = mp.VideoFileClip(video_path)
    try:
        if clip.audio is None:
            raise ValueError(f"Video has no audio track: {video_path}")

        # Use a unique file per call so overlapping requests cannot
        # overwrite (or delete) each other's audio.
        fd, audio_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            clip.audio.write_audiofile(audio_path)
        except Exception:
            # Do not leave a partial file behind when the export fails.
            os.remove(audio_path)
            raise
        return audio_path
    finally:
        # Release the clip's reader resources whether or not export succeeded.
        clip.close()
def transcribe_audio(audio_file):
    """Transcribe an audio file with the module-level ASR pipeline.

    Args:
        audio_file: Path of the audio file to transcribe.

    Returns:
        The ``"text"`` field of the pipeline result. Timestamps are
        requested from the pipeline but not returned.
    """
    with open(audio_file, "rb") as handle:
        raw_bytes = handle.read()

    # Decode at the rate the pipeline's feature extractor is configured for.
    sampling_rate = asr_pipeline.feature_extractor.sampling_rate
    waveform = ffmpeg_read(raw_bytes, sampling_rate)

    result = asr_pipeline(
        {"array": waveform, "sampling_rate": sampling_rate},
        batch_size=8,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=True,
    )
    return result["text"]
# NOTE(review): `spaces` is imported by this file but never used. If this app
# is deployed on ZeroGPU hardware this handler presumably needs the
# `@spaces.GPU` decorator -- confirm against the deployment target.
def analyze_video(prompt, video):
    """Answer a prompt about a video using its frames and audio transcription.

    Args:
        prompt: Question or instruction text sent to the model.
        video: Path string of the video, or an object whose ``name``
            attribute holds the path.

    Returns:
        A tuple of two display strings: the model response prefixed with
        "Analysis Result:", and the elapsed time formatted as
        "Processing Time: <seconds> seconds".

    Raises:
        gr.Error: If no video was provided.
    """
    if video is None:
        raise gr.Error("Please upload a video before running the analysis.")

    started = time.perf_counter()
    video_path = video if isinstance(video, str) else video.name
    encoded_video = encode_video(video_path)

    # Extract the audio and transcribe it; always remove the temporary audio
    # file, even when transcription fails.
    audio_path = extract_audio(video_path)
    try:
        transcription = transcribe_audio(audio_path)
    finally:
        os.remove(audio_path)

    # The transcription is passed through chat()'s `system_prompt` argument
    # instead of a {"role": "system"} entry in msgs: the model's chat()
    # implementation is understood to accept only "user"/"assistant" roles in
    # msgs and to build the system turn itself -- verify against the model
    # revision in use.
    context = [
        {"role": "user", "content": [prompt] + encoded_video},
    ]
    params = {
        'sampling': True,
        'top_p': 0.8,
        'top_k': 100,
        'temperature': 0.7,
        'repetition_penalty': 1.05,
        "max_new_tokens": 2048,
        "max_inp_length": 4352,
        "use_image_id": False,
        # Fewer slices per frame once many frames are sent.
        "max_slice_nums": 1 if len(encoded_video) > 16 else 2,
    }
    response = model.chat(
        image=None,
        msgs=context,
        tokenizer=tokenizer,
        system_prompt=f"Transcription of the video: {transcription}",
        **params,
    )

    elapsed = time.perf_counter() - started
    analysis_result = f"Analysis Result:\n{response}\n\n"
    processing_time = f"Processing Time: {elapsed:.2f} seconds"
    return analysis_result, processing_time
# Gradio UI: prompt and video inputs on the left, results on the right.
# NOTE(review): the original indentation was lost; the nesting below
# (button placed under the row, at Blocks level) is assumed -- confirm the
# intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Video Analyzer")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="What is the video about?")
            video_input = gr.Video(label="Upload Video")
        with gr.Column():
            analysis_result = gr.Textbox(label="Analysis Result")
            processing_time = gr.Textbox(label="Processing Time")

    analyze_button = gr.Button("Analyze Video")
    # Run the analysis on click and route its two return values to the
    # two output textboxes.
    analyze_button.click(
        fn=analyze_video,
        inputs=[prompt_input, video_input],
        outputs=[analysis_result, processing_time],
    )

demo.launch()