# VideoAnalyzer / app.py
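"""Gradio app that analyzes an uploaded video with MiniCPM-V 2.6.

It samples up to MAX_NUM_FRAMES frames from the video, transcribes the
audio track with Whisper, and feeds both the frames and the transcript
to the vision-language model together with the user's prompt.
"""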
import os
import time

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
from PIL import Image
from decord import VideoReader, cpu
import moviepy.editor as mp
import spaces
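
# 'spaces' provides the @spaces.GPU decorator used below, which requests a GPU
# per call on Hugging Face ZeroGPU Spaces.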
# Load models
model_path = 'openbmb/MiniCPM-V-2_6'
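# trust_remote_code pulls in MiniCPM-V's custom modeling code; bfloat16 halves
# memory use versus float32 while keeping a float32-sized exponent range.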
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to(device='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
# Load Whisper model
whisper_model = "openai/whisper-large-v3"
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=whisper_model,
    chunk_length_s=30,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
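# chunk_length_s=30 enables transformers' chunked long-form decoding, so audio
# longer than Whisper's 30-second window is split, transcribed, and stitched back.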
MAX_NUM_FRAMES = 64  # hard cap on frames handed to the vision-language model

def encode_image(image):
    """Load an image if needed and cap its longest side at 448 * 16 pixels
    before it is sent to the vision encoder."""
    if not isinstance(image, Image.Image):
        image = Image.open(image).convert("RGB")
    max_size = 448 * 16
    if max(image.size) > max_size:
        # Downscale while preserving the aspect ratio.
        w, h = image.size
        if w > h:
            new_w = max_size
            new_h = int(h * max_size / w)
        else:
            new_h = max_size
            new_w = int(w * max_size / h)
        image = image.resize((new_w, new_h), resample=Image.BICUBIC)
    return image

def encode_video(video_path):
    """Decode a video into at most MAX_NUM_FRAMES PIL frames, sampled ~1 per second."""
    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps())  # index step of roughly one frame per second
    frame_idx = list(range(0, len(vr), sample_fps))
    if len(frame_idx) > MAX_NUM_FRAMES:
        # Subsample uniformly (as in the upstream MiniCPM-V demo) so long videos
        # are covered end to end instead of only their first MAX_NUM_FRAMES seconds.
        gap = len(frame_idx) / MAX_NUM_FRAMES
        frame_idx = [frame_idx[int(i * gap + gap / 2)] for i in range(MAX_NUM_FRAMES)]
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(f.astype('uint8')) for f in frames]
    return [encode_image(f) for f in frames]

def extract_audio(video_path):
    """Write the audio track to a temporary WAV file; return None if the video has no audio."""
    video = mp.VideoFileClip(video_path)
    if video.audio is None:  # e.g. a screen recording with no sound
        video.close()
        return None
    audio_path = "temp_audio.wav"
    video.audio.write_audiofile(audio_path, logger=None)  # logger=None silences the progress bar
    video.close()
    return audio_path

def transcribe_audio(audio_file):
    """Transcribe a WAV file with Whisper and return the plain-text transcript."""
    with open(audio_file, "rb") as f:
        inputs = f.read()
    # Decode the raw bytes into a float array at Whisper's expected sampling rate.
    inputs = ffmpeg_read(inputs, asr_pipeline.feature_extractor.sampling_rate)
    inputs = {"array": inputs, "sampling_rate": asr_pipeline.feature_extractor.sampling_rate}
    transcription = asr_pipeline(
        inputs, batch_size=8, generate_kwargs={"task": "transcribe"}, return_timestamps=True
    )["text"]
    return transcription

@spaces.GPU
def analyze_video(prompt, video):
    start_time = time.time()
    # gr.Video passes a filepath string in recent Gradio versions, a file object in older ones.
    if isinstance(video, str):
        video_path = video
    else:
        video_path = video.name
    encoded_video = encode_video(video_path)
    # Extract and transcribe the audio track; videos without audio get an empty transcript.
    audio_path = extract_audio(video_path)
    transcription = transcribe_audio(audio_path) if audio_path else ""
    if audio_path:
        os.remove(audio_path)  # clean up the temporary audio file
    # Pass the transcript as context, then the user's prompt together with the sampled frames.
    context = [
        {"role": "system", "content": f"Transcription of the video: {transcription}"},
        {"role": "user", "content": [prompt] + encoded_video},
    ]
    params = {
        "sampling": True,
        "top_p": 0.8,
        "top_k": 100,
        "temperature": 0.7,
        "repetition_penalty": 1.05,
        "max_new_tokens": 2048,
        "max_inp_length": 4352,
        "use_image_id": False,
        # With many frames, use fewer image slices per frame to stay within the context window.
        "max_slice_nums": 1 if len(encoded_video) > 16 else 2,
    }
    response = model.chat(image=None, msgs=context, tokenizer=tokenizer, **params)
    processing_time = time.time() - start_time
    analysis_result = f"Analysis Result:\n{response}\n\n"
    timing_text = f"Processing Time: {processing_time:.2f} seconds"
    return analysis_result, timing_text

with gr.Blocks() as demo:
    gr.Markdown("# Video Analyzer")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="What is the video about?")
            video_input = gr.Video(label="Upload Video")
        with gr.Column():
            analysis_result = gr.Textbox(label="Analysis Result")
            processing_time = gr.Textbox(label="Processing Time")
    analyze_button = gr.Button("Analyze Video")
    analyze_button.click(
        fn=analyze_video,
        inputs=[prompt_input, video_input],
        outputs=[analysis_result, processing_time],
    )

demo.launch()