import base64
import os
import shutil
import tempfile
from io import BytesIO

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoTokenizer
import devicetorch

#import spaces
block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""
device = devicetorch.get(torch)

new_path = 'Lin-Chen/ShareCaptioner-Video'
tokenizer = AutoTokenizer.from_pretrained(new_path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    #new_path, torch_dtype=torch.float16, trust_remote_code=True).cuda().eval()
    new_path, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval()
#model.cuda()
model.to(device)
model.tokenizer = tokenizer
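
# Added safeguard (a sketch, not part of the original Space): the weights above are
# loaded in float16, which targets GPU inference; on a CPU-only machine some ops have
# no float16 kernels and inference can be very slow or fail. This assumes
# devicetorch.get(torch) returns the string "cpu" in that case, matching the
# string comparisons used further down.
if device == "cpu":
    model = model.float()  # fall back to float32 on CPU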

def padding_336(b, pad=336):
    """Pad the image height up to the next multiple of `pad` pixels with white."""
    width, height = b.size
    tar = int(np.ceil(height / pad) * pad)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(
        b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255])
    return b

def HD_transform(img, hd_num=25):
    """Resize the image so it covers at most `hd_num` 336x336 tiles while keeping
    the aspect ratio, then pad the short side to a multiple of 336."""
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    new_h = int(new_w / ratio)
    img = transforms.functional.resize(img, [new_h, new_w])
    img = padding_336(img, 336)
    if trans:
        img = img.transpose(Image.TRANSPOSE)
    return img
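
# Worked example (added for illustration, not part of the original code): a
# 1920x1080 frame with hd_num=25 settles on scale=6 in the loop above, so it is
# resized to 2016x1134 and then padded by padding_336 to 2016x1344, i.e. a
# 6x4 grid of 336-pixel tiles.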

def get_seq_frames(total_num_frames, desired_num_frames, start=None, end=None):
    """Pick `desired_num_frames` roughly evenly spaced frame indices in [start, end),
    always including the first and last frame of the range."""
    if start is None:
        assert end is None
        start, end = 0, total_num_frames
    print(f"{start=}, {end=}")
    desired_num_frames -= 2
    end = min(total_num_frames, end)
    start = max(start, 0)
    seg_size = float(end - start) / desired_num_frames
    seq = [start]
    for i in range(desired_num_frames):
        s = int(np.round(seg_size * i))
        e = int(np.round(seg_size * (i + 1)))
        seq.append(min(int(start + (s + e) // 2), total_num_frames - 1))
    return seq + [end - 1]
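
# Worked example (added for illustration, not part of the original code):
# get_seq_frames(100, 6) keeps frame 0 and frame 99 and fills in the midpoints
# of four equal segments, returning [0, 12, 37, 62, 87, 99].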

def model_gen(model, text, images, need_bos=True, hd_num=25, max_new_token=2048, beam=3, do_sample=False):
    """Interleave text and (optional) image embeddings and generate a response."""
    pt1 = 0
    embeds = []
    im_mask = []
    if images is None:
        images = []
        images_loc = []
    else:
        images = [images]
        images_loc = [0]
    for i, pts in enumerate(images_loc + [len(text)]):
        subtext = text[pt1:pts]
        if need_bos or len(subtext) > 0:
            text_embeds = model.encode_text(
                subtext, add_special_tokens=need_bos)
            embeds.append(text_embeds)
            #im_mask.append(torch.zeros(text_embeds.shape[:2]).cuda())
            im_mask.append(torch.zeros(text_embeds.shape[:2]).to(device))
            need_bos = False
        if i < len(images):
            try:
                image = Image.open(images[i]).convert('RGB')
            except Exception:
                image = images[i].convert('RGB')
            image = HD_transform(image, hd_num=hd_num)
            #image = model.vis_processor(image).unsqueeze(0).cuda()
            image = model.vis_processor(image).unsqueeze(0).to(device)
            image_embeds = model.encode_img(image)
            print(image_embeds.shape)
            embeds.append(image_embeds)
            #im_mask.append(torch.ones(image_embeds.shape[:2]).cuda())
            im_mask.append(torch.ones(image_embeds.shape[:2]).to(device))
        pt1 = pts
    embeds = torch.cat(embeds, dim=1)
    im_mask = torch.cat(im_mask, dim=1)
    im_mask = im_mask.bool()

    outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
                             temperature=1.0, max_new_tokens=max_new_token, num_beams=beam,
                             do_sample=do_sample, repetition_penalty=1.00)
    output_token = outputs[0]
    if output_token[0] == 0 or output_token[0] == 1:
        output_token = output_token[1:]
    output_text = model.tokenizer.decode(
        output_token, add_special_tokens=False)
    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip()
    output_text = output_text.split('<|im_end|>')[0].strip()
    return output_text

def img_process(imgs):
    """Stack a list of frames vertically onto a white canvas, labeling each one."""
    new_w = 0
    new_h = 0
    for im in imgs:
        w, h = im.size
        new_w = max(new_w, w)
        new_h += h + 20
    pad = max(new_w // 4, 100)
    new_w += 20
    new_h += 20
    try:
        font = ImageFont.truetype("SimHei.ttf", pad // 5)
    except OSError:
        # SimHei.ttf may not be installed; fall back to Pillow's built-in font.
        font = ImageFont.load_default()
    new_img = Image.new('RGB', (new_w + pad, new_h), 'white')
    draw = ImageDraw.Draw(new_img)
    curr_h = 10
    for idx, im in enumerate(imgs):
        w, h = im.size
        new_img.paste(im, (pad, curr_h))
        draw.text((0, curr_h + h // 2),
                  f'<IMAGE {idx}>', font=font, fill='black')
        if idx + 1 < len(imgs):
            draw.line([(0, curr_h + h + 10), (new_w + pad, curr_h + h + 10)],
                      fill='black', width=2)
        curr_h += h + 20
    return new_img

def load_quota_video(vis_path, start=None, end=None):
    """Decode the video and return one PIL frame roughly every two seconds."""
    vr = VideoReader(vis_path)
    total_frame_num = len(vr)
    fps = vr.get_avg_fps()
    if start is not None:
        assert end is not None
        start_frame = int(start * fps)
        end_frame = min(int(end * fps), total_frame_num)
    else:
        start_frame = 0
        end_frame = total_frame_num
    interval = int(2 * fps)
    frame_idx = list(range(start_frame, end_frame, interval))
    img_array = vr.get_batch(frame_idx).asnumpy()
    num_frm, H, W, _ = img_array.shape
    img_array = img_array.reshape(
        (1, num_frm, img_array.shape[-3], img_array.shape[-2], img_array.shape[-1]))
    clip_imgs = []
    for j in range(num_frm):
        clip_imgs.append(Image.fromarray(img_array[0, j]))
    return clip_imgs
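
# Example (added for illustration, not part of the original code): a 10-second
# clip at 30 fps has 300 frames and interval = 60, so frame_idx is
# [0, 60, 120, 180, 240] and five frames (one every ~2 seconds) are returned.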

def resize_image(image_path, max_size=1024):
    """Open an image and shrink it so the longer side is at most `max_size` pixels."""
    with Image.open(image_path) as img:
        width, height = img.size
        if width > max_size or height > max_size:
            if width > height:
                new_width = max_size
                new_height = int(height * (max_size / width))
            else:
                new_height = max_size
                new_width = int(width * (max_size / height))
        else:
            new_width = width
            new_height = height
        resized_img = img.resize((new_width, new_height))
        print(f"resized_img_size: {resized_img.size}")
        return resized_img

def encode_resized_image(image_path, max_size=1024):
    """Resize an image and return it as a base64-encoded JPEG string."""
    resized_img = resize_image(image_path, max_size)
    try:
        with BytesIO() as buffer:
            resized_img.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode('utf-8')
    except Exception:
        # Typically raised because JPEG cannot store an alpha channel;
        # convert to RGB and retry.
        with BytesIO() as buffer:
            rgb_img = resized_img.convert('RGB')
            rgb_img.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode('utf-8')

#@spaces.GPU(duration=60)
def generate_slidingcaptioning(video_path):
    """Sliding captioning: describe the first frame, then each consecutive pair of
    frames, and finally summarize all per-frame descriptions."""
    imgs = load_quota_video(video_path)
    q = 'This is the first frame of a video, describe it in detail.'
    query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    img = imgs[0]
    if device == "cuda":
        with torch.cuda.amp.autocast():
            response = model_gen(model, query, img, hd_num=9)
    else:
        response = model_gen(model, query, img, hd_num=9)
    print(response)
    responses = [response]
    images = [img]
    for idx in range(len(imgs) - 1):
        image1 = imgs[idx]
        image2 = imgs[idx + 1]
        prompt = "Here are the Video frame {} at {}.00 Second(s) and Video frame {} at {}.00 Second(s) of a video, describe what happened between them. What happened before is: {}".format(
            idx, int(idx * 2), idx + 1, int((idx + 1) * 2), response)
        width, height = image1.size
        new_img = Image.new('RGB', (width, 2 * height + 50), 'white')
        new_img.paste(image1, (0, 0))
        new_img.paste(image2, (0, height + 50))
        query = f'[UNUSED_TOKEN_146]user\n{prompt}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        if device == "cuda":
            with torch.cuda.amp.autocast():
                response = model_gen(model, query, new_img, hd_num=9)
        else:
            response = model_gen(model, query, new_img, hd_num=9)
        responses.append(response)
        images.append(new_img)
    prompt = 'Summarize the following per frame descriptions:\n'
    for idx, txt in enumerate(responses):
        prompt += 'Video frame {} at {}.00 Second(s) description: {}\n'.format(
            idx + 1, idx * 2, txt)
    query = f'[UNUSED_TOKEN_146]user\n{prompt}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    print(query)
    if device == "cuda":
        with torch.cuda.amp.autocast():
            summ = model_gen(model, query, None, hd_num=16)
    else:
        summ = model_gen(model, query, None, hd_num=16)
    print(summ)
    return summ

#@spaces.GPU(duration=60)
def generate_fastcaptioning(video_path):
    """Fast captioning: concatenate the sampled frames into one tall image and
    describe the whole video in a single pass."""
    q = 'Here are a few key frames of a video, describe this video in detail.'
    query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    imgs = load_quota_video(video_path)
    img = img_process(imgs)
    if device == "cuda":
        with torch.cuda.amp.autocast():
            response = model_gen(model, query, img, hd_num=16,
                                 do_sample=False, beam=3)
    else:
        response = model_gen(model, query, img, hd_num=16,
                             do_sample=False, beam=3)
    return response

#@spaces.GPU(duration=60)
def generate_promptrecaptioning(text):
    """Prompt re-captioning: expand a brief generation prompt into a detailed caption."""
    q = f'Translate this brief generation prompt into a detailed caption: {text}'
    query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    if device == "cuda":
        with torch.cuda.amp.autocast():
            response = model_gen(model, query, None)
    else:
        response = model_gen(model, query, None)
    return response

def save_video_to_local(video_path):
    """Copy the uploaded video into a local 'temp' directory under a random name."""
    os.makedirs('temp', exist_ok=True)
    filename = os.path.join('temp', next(
        tempfile._get_candidate_names()) + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename

with gr.Blocks(title='ShareCaptioner-Video', theme=gr.themes.Default(), css=block_css) as demo:
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()

    with gr.Row():
        gr.Markdown("### ShareCaptioner-Video is a four-in-one video captioning model with the following capabilities:\n1. Fast Captioning, 2. Sliding Captioning, 3. Clip Summarizing, 4. Prompt Re-Captioning")
    with gr.Row():
        gr.Markdown("(THE DEMO OF \"Clip Summarizing\" IS COMING SOON...)")
    with gr.Row():
        with gr.Column(scale=6):
            with gr.Row():
                video = gr.Video(label="Input Video")
            with gr.Row():
                textbox = gr.Textbox(
                    show_label=False, placeholder="Input Text", container=False
                )
            with gr.Row():
                with gr.Column(scale=2, min_width=50):
                    submit_btn_sc = gr.Button(
                        value="Sliding Captioning", variant="primary", interactive=True
                    )
                with gr.Column(scale=2, min_width=50):
                    submit_btn_fc = gr.Button(
                        value="Fast Captioning", variant="primary", interactive=True
                    )
                with gr.Column(scale=2, min_width=50):
                    submit_btn_pr = gr.Button(
                        value="Prompt Re-captioning", variant="primary", interactive=True
                    )
        with gr.Column(scale=4, min_width=200):
            with gr.Row():
                textbox_out = gr.Textbox(
                    show_label=False, placeholder="Output", container=False
                )

    submit_btn_sc.click(generate_slidingcaptioning, [video], [textbox_out])
    submit_btn_fc.click(generate_fastcaptioning, [video], [textbox_out])
    submit_btn_pr.click(generate_promptrecaptioning, [textbox], [textbox_out])

demo.launch()
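
# Usage note (added; not part of the original Space): assuming the packages imported
# above are installed (torch, torchvision, transformers, gradio, decord, pillow,
# numpy, devicetorch), saving this file as e.g. app.py and running `python app.py`
# downloads the Lin-Chen/ShareCaptioner-Video weights from the Hugging Face Hub and
# serves the demo at the local Gradio URL printed by demo.launch().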