cocktailpeanut's picture
update
9dfb729
import base64
import os
import shutil
import tempfile
from io import BytesIO
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoTokenizer
import devicetorch
#import spaces
# Extra CSS handed to gr.Blocks(css=...) further down: buttons inside the
# element with id "buttons" get a minimum width of 120px, or the full
# container width when that is smaller.
# NOTE(review): no component in the visible layout sets elem_id="buttons",
# so this rule may currently match nothing - confirm.
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
# Device is chosen through devicetorch rather than hard-coding CUDA.
# NOTE(review): presumably a device string such as "cuda" - the caption
# functions below compare it against that literal; confirm.
device = devicetorch.get(torch)

# Hugging Face repo id the tokenizer and model weights are loaded from.
new_path = 'Lin-Chen/ShareCaptioner-Video'

# trust_remote_code=True lets transformers run the Python code shipped in the
# model repo.
tokenizer = AutoTokenizer.from_pretrained(new_path, trust_remote_code=True)

# Load in half precision, move to the selected device once, and switch to
# inference mode.
model = AutoModel.from_pretrained(
    new_path, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval()

# model_gen() decodes through model.tokenizer, so attach the tokenizer here.
model.tokenizer = tokenizer
def padding_336(b, pad=336):
    """Pad an image vertically with white so its height is a multiple of ``pad``.

    Args:
        b: Image exposing ``.size`` as ``(width, height)``.
        pad: Height granularity in pixels.

    Returns:
        The padded image. The width is unchanged; the extra rows are split
        between top and bottom, with the bottom receiving the odd row.
    """
    _, height = b.size
    target_height = int(np.ceil(height / pad) * pad)
    extra_rows = target_height - height
    top = int(extra_rows / 2)
    bottom = extra_rows - top
    # Padding order is [left, top, right, bottom]; only top/bottom are used.
    return transforms.functional.pad(b, [0, top, 0, bottom], fill=[255, 255, 255])
def HD_transform(img, hd_num=25):
    """Resize and pad an image onto a grid of 336-pixel tiles.

    The image is handled in landscape orientation (portrait inputs are
    transposed first and transposed back at the end). Its width becomes the
    largest multiple of 336 for which ``tiles_across * tiles_down`` does not
    exceed ``hd_num``; the height follows the aspect ratio and is then
    padded up to a multiple of 336 by ``padding_336``.

    Args:
        img: PIL image.
        hd_num: Upper bound on the number of 336x336 tiles.

    Returns:
        The transformed PIL image.
    """
    width, height = img.size
    was_portrait = width < height
    if was_portrait:
        img = img.transpose(Image.TRANSPOSE)
        width, height = img.size
    aspect = width / height

    # Grow the tile count across while the resulting grid still fits hd_num.
    tiles_across = 0
    while (tiles_across + 1) * np.ceil((tiles_across + 1) / aspect) <= hd_num:
        tiles_across += 1

    target_w = int(tiles_across * 336)
    target_h = int(target_w / aspect)
    img = transforms.functional.resize(img, [target_h, target_w])
    img = padding_336(img, 336)

    if was_portrait:
        img = img.transpose(Image.TRANSPOSE)
    return img
def get_seq_frames(total_num_frames, desired_num_frames, start=None, end=None):
    """Pick roughly evenly spaced frame indices from ``[start, end)``.

    The first index is always ``start`` and the last is always ``end - 1``;
    the remaining ``desired_num_frames - 2`` indices are the midpoints of
    equal-width segments between them.

    Args:
        total_num_frames: Number of frames in the video; interior indices are
            clamped to ``total_num_frames - 1`` and ``end`` to this value.
        desired_num_frames: Number of indices to return. Values of 2 or less
            yield just the two endpoints.
        start: First frame index, or ``None`` for the whole video (``end``
            must then be ``None`` too).
        end: One past the last frame index.

    Returns:
        A list of integer frame indices.
    """
    if start is None:
        assert end is None
        start, end = 0, total_num_frames
    print(f"{start=}, {end=}")
    end = min(total_num_frames, end)
    start = max(start, 0)

    # The two endpoints are always emitted, so only the interior samples need
    # spacing. With no interior samples there is nothing to divide by
    # (previously desired_num_frames == 2 raised ZeroDivisionError).
    num_inner = desired_num_frames - 2
    if num_inner <= 0:
        return [start, end - 1]

    seg_size = float(end - start) / num_inner
    seq = [start]
    for i in range(num_inner):
        seg_start = int(np.round(seg_size * i))
        seg_end = int(np.round(seg_size * (i + 1)))
        midpoint = int(start + (seg_start + seg_end) // 2)
        seq.append(min(midpoint, total_num_frames - 1))
    seq.append(end - 1)
    return seq
def model_gen(model, text, images, need_bos=True, hd_num=25, max_new_token=2048, beam=3, do_sample=False):
    """Generate a text response for a prompt and at most one image.

    The input embedding sequence is built as: an (optionally BOS-prefixed)
    empty text segment, the image embedding if an image was given, then the
    prompt text. ``im_mask`` marks which positions come from the image.

    Args:
        model: Object exposing ``encode_text``, ``encode_img``,
            ``vis_processor``, ``generate`` and ``tokenizer``.
        text: Fully formatted prompt string.
        images: ``None`` for text-only generation, a PIL image, or anything
            ``PIL.Image.open`` accepts (path or file object).
        need_bos: Whether the first text segment is encoded with special
            tokens.
        hd_num: Tile budget forwarded to ``HD_transform``.
        max_new_token: Maximum number of tokens to generate.
        beam: Number of beams.
        do_sample: Forwarded to ``model.generate``. It was previously ignored
            and hard-coded to ``False``; the default keeps that behaviour.

    Returns:
        The decoded response, cut at the first ``[UNUSED_TOKEN_145]`` or
        ``<|im_end|>`` marker and stripped of surrounding whitespace.
    """
    pt1 = 0
    embeds = []
    im_mask = []
    if images is None:
        images = []
        images_loc = []
    else:
        # A single image is inserted at text offset 0, i.e. before the prompt.
        images = [images]
        images_loc = [0]

    for i, pts in enumerate(images_loc + [len(text)]):
        subtext = text[pt1:pts]
        if need_bos or len(subtext) > 0:
            text_embeds = model.encode_text(
                subtext, add_special_tokens=need_bos)
            embeds.append(text_embeds)
            im_mask.append(torch.zeros(text_embeds.shape[:2]).to(device))
            # Special tokens are only added to the very first segment.
            need_bos = False
        if i < len(images):
            source = images[i]
            if isinstance(source, Image.Image):
                image = source.convert('RGB')
            else:
                # convert() loads the pixel data, so the file can be closed
                # as soon as the RGB copy exists.
                with Image.open(source) as opened:
                    image = opened.convert('RGB')
            image = HD_transform(image, hd_num=hd_num)
            image = model.vis_processor(image).unsqueeze(0).to(device)
            image_embeds = model.encode_img(image)
            print(image_embeds.shape)
            embeds.append(image_embeds)
            im_mask.append(torch.ones(image_embeds.shape[:2]).to(device))
        pt1 = pts

    embeds = torch.cat(embeds, dim=1)
    im_mask = torch.cat(im_mask, dim=1)
    im_mask = im_mask.bool()
    outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
                             temperature=1.0, max_new_tokens=max_new_token, num_beams=beam,
                             do_sample=do_sample, repetition_penalty=1.00)
    output_token = outputs[0]
    # Drop a leading token with id 0 or 1 (presumably a pad/BOS id - confirm
    # against the tokenizer).
    if output_token[0] == 0 or output_token[0] == 1:
        output_token = output_token[1:]
    output_text = model.tokenizer.decode(
        output_token, add_special_tokens=False)
    output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip()
    output_text = output_text.split('<|im_end|>')[0].strip()
    return output_text
def img_process(imgs):
    """Stack frames vertically into a single labelled image.

    Each frame is pasted below the previous one with a 20px gap, a
    ``<IMAGE n>`` label is drawn in a left margin, and consecutive frames are
    separated by a horizontal black line.

    Args:
        imgs: Sequence of PIL images.

    Returns:
        A new white-background RGB PIL image containing all frames.
    """
    widest = max((im.size[0] for im in imgs), default=0)
    frames_height = sum(im.size[1] + 20 for im in imgs)

    # Left margin reserved for the labels: a quarter of the widest frame,
    # but never less than 100px.
    pad = max(widest // 4, 100)
    new_w = widest + 20
    new_h = frames_height + 20

    try:
        font = ImageFont.truetype("SimHei.ttf", pad // 5)
    except OSError:
        # SimHei is not installed on many systems; fall back to Pillow's
        # built-in font instead of failing the whole request.
        font = ImageFont.load_default()

    new_img = Image.new('RGB', (new_w + pad, new_h), 'white')
    draw = ImageDraw.Draw(new_img)
    curr_h = 10
    last_idx = len(imgs) - 1
    for idx, im in enumerate(imgs):
        _, h = im.size
        new_img.paste(im, (pad, curr_h))
        draw.text((0, curr_h + h // 2),
                  f'<IMAGE {idx}>', font=font, fill='black')
        if idx < last_idx:
            # Separator halfway through the gap to the next frame.
            draw.line([(0, curr_h + h + 10), (new_w + pad, curr_h + h + 10)],
                      fill='black', width=2)
        curr_h += h + 20
    return new_img
def load_quota_video(vis_path, start=None, end=None):
    """Sample one frame roughly every two seconds from a video.

    Args:
        vis_path: Path of the video, passed to ``decord.VideoReader``.
        start: Optional start time in seconds. When given, ``end`` must be
            given as well.
        end: Optional end time in seconds; clamped to the end of the video.

    Returns:
        A list of PIL images, one per sampled frame.
    """
    vr = VideoReader(vis_path)
    total_frame_num = len(vr)
    fps = vr.get_avg_fps()
    if start is not None:
        assert end is not None
        start_frame = int(start * fps)
        end_frame = min(int(end * fps), total_frame_num)
    else:
        start_frame = 0
        end_frame = total_frame_num

    # int(2 * fps) is 0 for streams below 0.5 fps, and range() rejects a
    # zero step, so never step by less than one frame.
    interval = max(1, int(2 * fps))
    frame_idx = list(range(start_frame, end_frame, interval))
    frames = vr.get_batch(frame_idx).asnumpy()
    return [Image.fromarray(frame) for frame in frames]
def resize_image(image_path, max_size=1024):
    """Open an image and shrink it so neither side exceeds ``max_size``.

    The aspect ratio is preserved. Images that already fit are returned at
    their original size (still as a new image object, independent of the
    file handle).

    Args:
        image_path: Path or file object accepted by ``PIL.Image.open``.
        max_size: Maximum width and height in pixels.

    Returns:
        The resized PIL image.
    """
    with Image.open(image_path) as img:
        width, height = img.size
        new_width, new_height = width, height
        if width > max_size or height > max_size:
            # The longer side becomes exactly max_size. The shorter side is
            # clamped to 1px so that extreme aspect ratios cannot produce a
            # zero-sized image, which resize() rejects.
            if width > height:
                new_width = max_size
                new_height = max(1, int(height * (max_size / width)))
            else:
                new_height = max_size
                new_width = max(1, int(width * (max_size / height)))
        resized_img = img.resize((new_width, new_height))
    print(f"resized_img_size: {resized_img.size}")
    return resized_img
def encode_resized_image(image_path, max_size=1024):
    """Return a base64-encoded JPEG of the image, resized to fit ``max_size``.

    Args:
        image_path: Path or file object accepted by ``resize_image``.
        max_size: Maximum width and height in pixels.

    Returns:
        The JPEG bytes encoded as a base64 ``str``.
    """
    def _jpeg_base64(image):
        # Encode one image into an in-memory JPEG and base64 it.
        with BytesIO() as buffer:
            image.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode('utf-8')

    resized_img = resize_image(image_path, max_size)
    try:
        return _jpeg_base64(resized_img)
    except OSError:
        # Pillow refuses to write some modes (for example ones with an alpha
        # channel) as JPEG; retry after converting to RGB.
        return _jpeg_base64(resized_img.convert('RGB'))
#@spaces.GPU(duration=60)
def generate_slidingcaptioning(video_path):
    """Caption a video frame pair by frame pair, then summarise the captions.

    Frames are sampled by ``load_quota_video``. The first frame is described
    on its own; every following step shows the model two consecutive frames
    (stacked vertically with a 50px white gap) together with the previous
    step's description. Finally all per-step descriptions are summarised in
    one text-only call.

    Args:
        video_path: Path of the input video.

    Returns:
        The summary text produced by the model.

    Raises:
        ValueError: If no frame could be sampled from the video.
    """
    def _generate(query, image, hd_num):
        # Autocast is only entered on CUDA; other devices call the model directly.
        if device == "cuda":
            with torch.cuda.amp.autocast():
                return model_gen(model, query, image, hd_num=hd_num)
        return model_gen(model, query, image, hd_num=hd_num)

    imgs = load_quota_video(video_path)
    if not imgs:
        raise ValueError("No frames could be sampled from the video.")

    q = 'This is the first frame of a video, describe it in detail.'
    query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    response = _generate(query, imgs[0], 9)
    print(response)
    responses = [response]

    for idx, (image1, image2) in enumerate(zip(imgs, imgs[1:])):
        # Timestamps assume the two-second sampling step of load_quota_video.
        # `response` still holds the previous step's description here.
        prompt = "Here are the Video frame {} at {}.00 Second(s) and Video frame {} at {}.00 Second(s) of a video, describe what happend between them. What happend before is: {}".format(
            idx, int(idx*2), idx+1, int((idx+1)*2), response)
        width, height = image1.size
        new_img = Image.new('RGB', (width, 2*height+50), 'white')
        new_img.paste(image1, (0, 0))
        new_img.paste(image2, (0, height+50))
        query = f'[UNUSED_TOKEN_146]user\n{prompt}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
        response = _generate(query, new_img, 9)
        responses.append(response)

    prompt = 'Summarize the following per frame descriptions:\n' + ''.join(
        'Video frame {} at {}.00 Second(s) description: {}\n'.format(
            idx+1, idx*2, txt)
        for idx, txt in enumerate(responses))
    query = f'[UNUSED_TOKEN_146]user\n{prompt}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    print(query)
    summ = _generate(query, None, 16)
    print(summ)
    return summ
#@spaces.GPU(duration=60)
def generate_fastcaptioning(video_path):
    """Caption a video in one model call from a sheet of its sampled frames.

    Args:
        video_path: Path of the input video.

    Returns:
        The caption text produced by the model.
    """
    question = 'Here are a few key frames of a video, discribe this video in detail.'
    query = f'[UNUSED_TOKEN_146]user\n{question}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'

    # All sampled frames are merged into one labelled image for the model.
    frame_sheet = img_process(load_quota_video(video_path))

    gen_kwargs = dict(hd_num=16, do_sample=False, beam=3)
    if device != "cuda":
        return model_gen(model, query, frame_sheet, **gen_kwargs)
    # Autocast is only entered on CUDA.
    with torch.cuda.amp.autocast():
        return model_gen(model, query, frame_sheet, **gen_kwargs)
#@spaces.GPU(duration=60)
def generate_promptrecaptioning(text):
    """Expand a brief generation prompt into a detailed caption.

    Args:
        text: The short prompt entered by the user.

    Returns:
        The detailed caption produced by the model (text-only call).
    """
    instruction = f'Translate this brief generation prompt into a detailed caption: {text}'
    query = f'[UNUSED_TOKEN_146]user\n{instruction}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'

    if device != "cuda":
        return model_gen(model, query, None)
    # Autocast is only entered on CUDA.
    with torch.cuda.amp.autocast():
        return model_gen(model, query, None)
def save_video_to_local(video_path):
    """Copy a video into the local ``temp`` directory under a unique name.

    The ``temp`` directory is created if it does not exist yet.

    Args:
        video_path: Path of the video to copy.

    Returns:
        The path of the copy, of the form ``temp/<unique name>.mp4``
        (relative to the current working directory).
    """
    os.makedirs('temp', exist_ok=True)
    # mkstemp reserves a unique file name through the public tempfile API;
    # the descriptor itself is not needed because copyfile reopens the path.
    fd, reserved_path = tempfile.mkstemp(suffix='.mp4', dir='temp')
    os.close(fd)
    # Rebuild the path relative to 'temp' so the return value keeps the same
    # shape whether or not mkstemp reports an absolute path.
    filename = os.path.join('temp', os.path.basename(reserved_path))
    shutil.copyfile(video_path, filename)
    return filename
# Gradio UI: a video input and a text input on the left, three action buttons
# beneath them, and a single output textbox on the right.
# NOTE(review): the source's indentation was lost; the nesting below is
# reconstructed from the column scales (6 + 4) - confirm against the intended
# layout.
with gr.Blocks(title='ShareCaptioner-Video', theme=gr.themes.Default(), css=block_css) as demo:
    # These three State components are created but not referenced anywhere
    # else in the visible file.
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    with gr.Row():
        gr.Markdown("### The ShareCaptioner-Video is a Four-in-One exceptional video captioning model with the following capabilities:\n1. Fast captioning, 2. Sliding Captioning, 3. Clip Summarizing, 4. Prompt Re-Captioning")
    with gr.Row():
        gr.Markdown("(THE DEMO OF \"Clip Summarizing\" IS COMING SOON...)")
    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=6):
            with gr.Row():
                video = gr.Video(label="Input Video")
            with gr.Row():
                textbox = gr.Textbox(
                    show_label=False, placeholder="Input Text", container=False
                )
            with gr.Row():
                with gr.Column(scale=2, min_width=50):
                    submit_btn_sc = gr.Button(
                        value="Sliding Captioning", variant="primary", interactive=True
                    )
                with gr.Column(scale=2, min_width=50):
                    submit_btn_fc = gr.Button(
                        value="Fast Captioning", variant="primary", interactive=True
                    )
                with gr.Column(scale=2, min_width=50):
                    submit_btn_pr = gr.Button(
                        value="Prompt Re-captioning", variant="primary", interactive=True
                    )
        # Right column: model output.
        with gr.Column(scale=4, min_width=200):
            with gr.Row():
                textbox_out = gr.Textbox(
                    show_label=False, placeholder="Output", container=False
                )
    # The two video buttons read the video component; prompt re-captioning
    # reads the text box. All three write to the same output textbox.
    submit_btn_sc.click(generate_slidingcaptioning, [video], [textbox_out])
    submit_btn_fc.click(generate_fastcaptioning, [video], [textbox_out])
    submit_btn_pr.click(generate_promptrecaptioning, [textbox], [textbox_out])
# Starts the Gradio server as soon as this module is executed.
demo.launch()