import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from accelerate import load_checkpoint_and_dispatch, PartialState
from accelerate.utils import gather_object
from decord import VideoReader
from PIL import Image
from natsort import natsorted
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

import tinychat.utils.constants

from tinychat.models.vila_llama import VilaLlamaForCausalLM
from tinychat.stream_generators.llava_stream_gen import LlavaStreamGenerator
from tinychat.utils.conversation_utils import gen_params
from tinychat.utils.llava_image_processing import process_images
from tinychat.utils.prompt_templates import (
    get_image_token,
    get_prompter,
    get_stop_token_ids,
)
from tinychat.utils.tune import (
    device_warmup,
    tune_llava_patch_embedding,
)

from utils.filter import filter
from utils.logger import logger

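# Seeded generation with temperature and top-p fixed at 1.0 (no extra scaling
# or truncation), so caption runs are reproducible.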
gen_params.seed = 1
gen_params.temp = 1.0
gen_params.top_p = 1.0

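# Uniformly sample `num_sampled_frames` frames across the video and return
# them as PIL images.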
def extract_uniform_frames(video_path: str, num_sampled_frames: int = 8):
    vr = VideoReader(video_path)
    sampled_frame_idx_list = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
    sampled_frame_list = []
    for idx in sampled_frame_idx_list:
        sampled_frame = Image.fromarray(vr[idx].asnumpy())
        sampled_frame_list.append(sampled_frame)

    return sampled_frame_list

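# Consume the output stream and return the text of the last yielded chunk,
# which holds the full decoded caption.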
def stream_output(output_stream):
    for outputs in output_stream:
        output_text = outputs["text"]
        output_text = output_text.strip().split(" ")

    return " ".join(output_text)

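# No-op used below to stub out torch's weight-initialization routines, since
# the real weights are loaded from a checkpoint.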
def skip(*args, **kwargs):
    pass

def parse_args():
    parser = argparse.ArgumentParser(description="Recaption videos with VILA1.5.")
    parser.add_argument(
        "--video_metadata_path",
        type=str,
        default=None,
        help="The path to the video dataset metadata (csv/jsonl).",
    )
    parser.add_argument(
        "--video_path_column",
        type=str,
        default="video_path",
        help="The column containing the video path (an absolute path or a path relative to video_folder).",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="caption",
        help="The column containing the caption.",
    )
    parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
    parser.add_argument(
        "--input_prompt",
        type=str,
        default="<video>\n Elaborate on the visual and narrative elements of the video in detail.",
    )
    parser.add_argument("--model_type", type=str, default="LLaMa", help="The type of the model.")
    parser.add_argument("--model_path", type=str, default="Efficient-Large-Model/Llama-3-VILA1.5-8b-AWQ")
    parser.add_argument("--quant_path", type=str, default=None)
    parser.add_argument("--precision", type=str, default="W4A16", help="The compute precision.")
    parser.add_argument("--num_sampled_frames", type=int, default=8)
    parser.add_argument(
        "--saved_path",
        type=str,
        required=True,
        help="The save path to the output results (csv/jsonl).",
    )
    parser.add_argument(
        "--saved_freq",
        type=int,
        default=100,
        help="The frequency to save the output results.",
    )

    parser.add_argument(
        "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)."
    )
    parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.")
    parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
    parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
    parser.add_argument(
        "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video aesthetic score metadata (csv/jsonl)."
    )
    parser.add_argument("--min_asethetic_score", type=float, default=4.0, help="The aesthetic score threshold.")
    parser.add_argument(
        "--asethetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video aesthetic score (SigLIP) metadata (csv/jsonl)."
    )
    parser.add_argument("--min_asethetic_score_siglip", type=float, default=4.0, help="The aesthetic score (SigLIP) threshold.")
    parser.add_argument(
        "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)."
    )
    parser.add_argument("--min_text_score", type=float, default=0.02, help="The text score threshold.")
    parser.add_argument(
        "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)."
    )
    parser.add_argument("--min_motion_score", type=float, default=2, help="The motion score threshold.")

    args = parser.parse_args()
    return args

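# Example invocation (script name and paths below are placeholders):
#   python caption_video_vila.py --video_metadata_path metadata.jsonl \
#       --video_folder videos/ --saved_path vila_captions.jsonl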
def main(args):
    if args.video_metadata_path.endswith(".csv"):
        video_metadata_df = pd.read_csv(args.video_metadata_path)
    elif args.video_metadata_path.endswith(".jsonl"):
        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
    else:
        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
    video_path_list = video_metadata_df[args.video_path_column].tolist()
    video_path_list = [os.path.basename(video_path) for video_path in video_path_list]

    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
        raise ValueError("The saved_path must end with .csv or .jsonl.")

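    # Resume support: if the output file already exists, skip videos whose
    # captions are already recorded in it.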
    if os.path.exists(args.saved_path):
        if args.saved_path.endswith(".csv"):
            saved_metadata_df = pd.read_csv(args.saved_path)
        elif args.saved_path.endswith(".jsonl"):
            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
        video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
        logger.info(
            f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed."
        )

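    # Filter out videos that fail the optional metadata thresholds
    # (resolution, duration, aesthetic/text/motion scores).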
    video_path_list = filter(
        video_path_list,
        basic_metadata_path=args.basic_metadata_path,
        min_resolution=args.min_resolution,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        asethetic_score_metadata_path=args.asethetic_score_metadata_path,
        min_asethetic_score=args.min_asethetic_score,
        asethetic_score_siglip_metadata_path=args.asethetic_score_siglip_metadata_path,
        min_asethetic_score_siglip=args.min_asethetic_score_siglip,
        text_score_metadata_path=args.text_score_metadata_path,
        min_text_score=args.min_text_score,
        motion_score_metadata_path=args.motion_score_metadata_path,
        min_motion_score=args.min_motion_score,
    )
    video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]

    video_path_list = natsorted(video_path_list)

    state = PartialState()

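    # Skip random weight initialization when building the model; the actual
    # weights are loaded from the checkpoint below.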
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.kaiming_normal_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.model_path, "llm"), use_fast=False)
    tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN_IDX = (
        tokenizer.convert_tokens_to_ids(
            [tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN]
        )[0]
    )
    config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
    model = VilaLlamaForCausalLM(config).half()
    vision_tower = model.get_vision_tower()
    image_processor = vision_tower.image_processor

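    # Load the LLM weights according to the requested precision: fp16 shards
    # via accelerate (W16A16) or AWQ 4-bit quantized weights (W4A16).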
    if args.precision == "W16A16":
        pbar = tqdm(range(1))
        pbar.set_description("Loading checkpoint shards")
        for i in pbar:
            model.llm = load_checkpoint_and_dispatch(
                model.llm,
                os.path.join(args.model_path, "llm"),
                no_split_module_classes=[
                    "OPTDecoderLayer",
                    "LlamaDecoderLayer",
                    "BloomBlock",
                    "MPTBlock",
                    "DecoderLayer",
                    "CLIPEncoderLayer",
                ],
            ).to(state.device)
        model = model.to(state.device)

    elif args.precision == "W4A16":
        from tinychat.utils.load_quant import load_awq_model

        if args.quant_path is None:
            if "VILA1.5-3b-s2-AWQ" in args.model_path:
                args.quant_path = os.path.join(args.model_path, "llm/vila-1.5-3b-s2-w4-g128-awq-v2.pt")
            elif "VILA1.5-3b-AWQ" in args.model_path:
                args.quant_path = os.path.join(args.model_path, "llm/vila-1.5-3b-w4-g128-awq-v2.pt")
            elif "Llama-3-VILA1.5-8b-AWQ" in args.model_path:
                args.quant_path = os.path.join(args.model_path, "llm/llama-3-vila1.5-8b-w4-g128-awq-v2.pt")
            elif "VILA1.5-13b-AWQ" in args.model_path:
                args.quant_path = os.path.join(args.model_path, "llm/vila-1.5-13b-w4-g128-awq-v2.pt")
            elif "VILA1.5-40b-AWQ" in args.model_path:
                args.quant_path = os.path.join(args.model_path, "llm/vila-1.5-40b-w4-g128-awq-v2.pt")
        model.llm = load_awq_model(model.llm, args.quant_path, 4, 128, state.device)
        from tinychat.modules import (
            make_fused_mlp,
            make_fused_vision_attn,
            make_quant_attn,
            make_quant_norm,
        )

        make_quant_attn(model.llm, state.device)
        make_quant_norm(model.llm)

        model = model.to(state.device)

    else:
        raise NotImplementedError(f"Precision {args.precision} is not supported.")

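    # Warm up the device and tune the vision tower's patch embedding before
    # running inference.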
    device_warmup(state.device)
    tune_llava_patch_embedding(vision_tower, device=state.device)

    stream_generator = LlavaStreamGenerator

    model_prompter = get_prompter(args.model_type, args.model_path, False, False)
    stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)

    model.eval()

    index = len(video_path_list) - len(video_path_list) % state.num_processes
    logger.info(
        f"Drop {len(video_path_list) % state.num_processes} videos to ensure each process handles the same number of videos."
    )
    video_path_list = video_path_list[:index]
    logger.info(f"{len(video_path_list)} videos are to be processed.")

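    # Each process captions its own shard of the video list; shards are
    # gathered onto the main process before saving.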
    result_dict = {args.video_path_column: [], args.caption_column: []}
    with state.split_between_processes(video_path_list) as splitted_video_path_list:
        for i, video_path in enumerate(tqdm(splitted_video_path_list)):
            try:
                image_list = extract_uniform_frames(video_path, args.num_sampled_frames)
                image_num = len(image_list)

                image_tensor = process_images(image_list, image_processor, model.config)
                if type(image_tensor) is list:
                    image_tensor = [
                        image.to(state.device, dtype=torch.float16) for image in image_tensor
                    ]
                else:
                    image_tensor = image_tensor.to(state.device, dtype=torch.float16)

                input_prompt = args.input_prompt

                image_token = get_image_token(model, args.model_path)
                image_token_holder = tinychat.utils.constants.LLAVA_DEFAULT_IM_TOKEN_PLACE_HOLDER
                im_token_count = input_prompt.count(image_token_holder)
                if im_token_count == 0:
                    model_prompter.insert_prompt(image_token * image_num + input_prompt)
                else:
                    assert im_token_count == image_num
                    input_prompt = input_prompt.replace(image_token_holder, image_token)
                    model_prompter.insert_prompt(input_prompt)
                output_stream = stream_generator(
                    model,
                    tokenizer,
                    model_prompter.model_input,
                    gen_params,
                    device=state.device,
                    stop_token_ids=stop_token_ids,
                    image_tensor=image_tensor,
                )
                outputs = stream_output(output_stream)
                if len(outputs) != 0:
                    result_dict[args.video_path_column].append(Path(video_path).name)
                    result_dict[args.caption_column].append(outputs)

            except Exception as e:
                logger.warning(f"VILA with {video_path} failed. Error is {e}.")

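            # Periodically gather captions from every rank, append them to the
            # output file on the main process, and clear the local buffer.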
            if i != 0 and i % args.saved_freq == 0:
                state.wait_for_everyone()
                gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
                if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
                    result_df = pd.DataFrame(gathered_result_dict)
                    if args.saved_path.endswith(".csv"):
                        header = not os.path.exists(args.saved_path)
                        result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
                    elif args.saved_path.endswith(".jsonl"):
                        result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
                    logger.info(f"Save result to {args.saved_path}.")
                for k in result_dict.keys():
                    result_dict[k] = []

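    # Final gather and save for whatever remains after the loop.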
    state.wait_for_everyone()
    gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
    if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
        result_df = pd.DataFrame(gathered_result_dict)
        if args.saved_path.endswith(".csv"):
            header = not os.path.exists(args.saved_path)
            result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
        elif args.saved_path.endswith(".jsonl"):
            result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
        logger.info(f"Save result to {args.saved_path}.")


if __name__ == "__main__":
    args = parse_args()
    main(args)