videollm-online / models /arguments_live.py
chenjoya's picture
Upload 9 files
7d1b5a5 verified
raw
history blame
No virus
2.12 kB
from dataclasses import dataclass, field
from transformers import TrainingArguments
@dataclass
class LiveTrainingArguments(TrainingArguments):
    """Shared training options for the live models, layered on ``TrainingArguments``.

    The version-specific presets in this file subclass this dataclass and
    override the frame-token related fields.
    """

    # Version tag; ``get_args_class`` in this file maps tags to preset classes.
    live_version: str = "live1+"
    system_prompt: str = (
        "A multimodal AI assistant is helping users with some activities. "
        "Below is their conversation, interleaved with the list of video frames "
        "received by the assistant."
    )

    # Dataset selections; both default to None (nothing selected here).
    train_datasets: list[str] = None
    eval_datasets: list[str] = None
    stream_loss_weight: float = 1.0

    # Pretrained checkpoints (identifiers look like Hugging Face hub names).
    llm_pretrained: str = "meta-llama/Meta-Llama-3-8B-Instruct"
    vision_pretrained: str = "google/siglip-large-patch16-384"

    # Regex choosing the LoRA target modules -- presumably matched against
    # module names by the model-building code; verify against the caller.
    lora_modules: str = (
        "model.*(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)"
        "|lm_head$"
    )
    lora_r: int = 128
    lora_alpha: int = 256
    # default_factory gives every instance its own list.
    finetune_modules: list[str] = field(default_factory=lambda: ["connector"])

    # Frame sampling / frame-token layout.
    frame_fps: int = 2  # for training. inference can be 10
    frame_token_cls: bool = None  # unset here; the presets below set it
    frame_token_pooled: list[int] = None
    frame_resolution: int = 384
    frame_token_interval: str = None
    frame_token_interval_threshold: float = 0.0

    augmentation: bool = False
    attn_implementation: str = "flash_attention_2"
    output_dir: str = "outputs/debug"
@dataclass
class LiveOneTrainingArguments(LiveTrainingArguments):
    """Preset of ``LiveTrainingArguments`` for the ``'live1'`` version tag."""

    live_version: str = "live1"

    # Frame-token layout: one token per frame, empty interval string.
    frame_token_cls: bool = True
    frame_num_tokens: int = 1
    frame_token_interval: str = ""

    embed_mark: str = "2fps_384_1"
    max_num_frames: int = 7200  # 1 h at 2 fps -> 7200 frames
@dataclass
class LiveOnePlusTrainingArguments(LiveTrainingArguments):
    """Preset of ``LiveTrainingArguments`` for the ``'live1+'`` version tag."""

    live_version: str = "live1+"

    # Frame-token layout: 1 + 3x3 = 10 tokens per frame.
    frame_token_cls: bool = True
    # default_factory gives every instance its own [3, 3] list.
    frame_token_pooled: list[int] = field(default_factory=lambda: [3, 3])
    frame_num_tokens: int = 10  # 1+3x3

    embed_mark: str = "2fps_384_1+3x3"
    frame_token_interval: str = ","
    max_num_frames: int = 1200  # 10 min at 2 fps -> 1200 frames
def get_args_class(live_version: str):
    """Return the training-arguments dataclass for ``live_version``.

    Args:
        live_version: Version tag; ``'live1'`` and ``'live1+'`` are supported.

    Returns:
        ``LiveOneTrainingArguments`` for ``'live1'`` or
        ``LiveOnePlusTrainingArguments`` for ``'live1+'`` -- the class itself,
        not an instance.

    Raises:
        NotImplementedError: If ``live_version`` is not a supported tag.
    """
    if live_version == 'live1':
        return LiveOneTrainingArguments
    if live_version == 'live1+':
        return LiveOnePlusTrainingArguments
    # Name the rejected value so a mistyped version is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported live_version {live_version!r}; expected 'live1' or 'live1+'."
    )