"""LiveLlama: ``LlamaForCausalLM`` combined with ``LiveMixin`` and a vision-to-text connector."""
import torch
from torch import nn
from transformers import LlamaForCausalLM, Cache
from transformers.activations import GELUActivation
from transformers.utils import logging

from .configuration_live_llama import LiveLlamaConfig
from ..modeling_live import build_live, LiveMixin

logger = logging.get_logger(__name__)


class LiveLlamaForCausalLM(LlamaForCausalLM, LiveMixin):
    """Llama causal LM with a two-layer MLP ``connector`` and a frame-weighted LM loss.

    Frame/token embedding is delegated to ``self.joint_embed`` (provided by
    ``LiveMixin``, not visible here).
    """

    config_class = LiveLlamaConfig
    # Suppress missing-key warnings for these modules when loading a checkpoint
    # (e.g. a plain Llama checkpoint that has no connector weights).
    _keys_to_ignore_on_load_missing = ['vision_encoder', 'connector']

    def __init__(self, config: LiveLlamaConfig):
        super().__init__(config)
        # Linear(vision_hidden_size -> hidden_size) -> GELU -> Linear(hidden_size -> hidden_size);
        # presumably projects vision-encoder features into the LM embedding space.
        # GELUActivation's only argument is the boolean `use_gelu_python`; the original passed
        # `config.hidden_size`, which silently selected the slower pure-Python GELU.
        self.connector = torch.nn.Sequential(
            torch.nn.Linear(config.vision_hidden_size, config.hidden_size, bias=True),
            GELUActivation(),
            torch.nn.Linear(config.hidden_size, config.hidden_size, bias=True),
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        frames: torch.FloatTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: list[torch.FloatTensor] = None,
        inputs_embeds: torch.FloatTensor = None,
        labels: torch.LongTensor = None,
        use_cache: bool = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        cache_position: torch.LongTensor = None,
        **kwargs,
    ):
        """Run the Llama decoder on joint text/frame embeddings and optionally compute a weighted loss.

        Args:
            input_ids: Token ids. Used to build ``inputs_embeds`` when that is not given, and
                always needed when ``labels`` is given (for the placeholder mask).
            frames: Passed to ``self.joint_embed`` together with ``input_ids``.
            labels: Per-position targets, compared at the SAME position as the logits (no
                shift is applied here; presumably labels arrive pre-shifted — confirm with
                the data pipeline). Positions with ``-100`` are ignored by ``cross_entropy``.
            return_dict: ``None`` falls back to ``config.use_return_dict``, as in transformers.
            **kwargs: Accepted and ignored.

        Returns:
            A ``ModelOutput`` (with ``loss`` as its first field when ``labels`` is given) or,
            when ``return_dict`` resolves to False, the tuple ``(loss, logits, ...)`` /
            ``(logits, ...)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.joint_embed(input_ids, frames)

        outputs = super().forward(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            # `labels` is deliberately not forwarded: the weighted loss is computed below.
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        loss = None
        if labels is not None:
            # No labels went to the parent, so element 0 is the logits for tuple and dict outputs alike.
            logits = outputs[0]
            # Predictions made at v_placeholder_id positions (presumably frame slots) are scaled by
            # stream_loss_weight; every other position gets weight 1.
            v_mask = input_ids.flatten(0, 1) == self.config.v_placeholder_id
            weight = v_mask * self.config.stream_loss_weight + ~v_mask
            loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction='none') * weight
            # Normalise by the number of supervised positions.
            # NOTE(review): not guarded — if no label is >= 0 this divides by zero (NaN).
            loss = loss.sum() / (labels >= 0).sum()

        if not return_dict:
            # `outputs` is a plain tuple here whose first element is the logits; keep it
            # (the original `outputs[1:]` silently dropped the logits).
            return (loss,) + outputs if loss is not None else outputs

        if loss is not None:
            # Assigning `outputs.loss` on a ModelOutput whose `loss` was None only sets an
            # attribute, so `outputs["loss"]` / `to_tuple()` would not see it. Rebuilding the
            # dataclass registers `loss` as a real (first) field.
            outputs = type(outputs)(loss=loss, **outputs)
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """Build the per-step model inputs for ``generate``.

        Structure follows the stock transformers Llama implementation (see the upstream TODO
        below). Two NOTE-marked deviations:
        - ``position_ids`` are sliced from ``past_length`` onward.
        - ``inputs_embeds`` is used whenever it extends past the cached prefix, not only on the
          first step.

        Returns:
            dict with either ``input_ids`` or ``inputs_embeds``, plus ``position_ids``,
            ``cache_position``, ``past_key_values``, ``use_cache`` and ``attention_mask``.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # NOTE: slice by the cached length (not by input_ids' length), so position_ids
                # cover exactly the uncached positions.
                position_ids = position_ids[:, past_length:]

        # NOTE: feed embeddings for every position not yet in the cache whenever inputs_embeds
        # reaches past it (presumably so frame embeddings can be injected mid-generation).
        if inputs_embeds is not None and past_length < inputs_embeds.size(1):
            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,  # length of the new inputs, starting after the cached prefix
                "cache_position": cache_position,  # positions not yet in the cache
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,  # spans cache + new inputs
            }
        )
        return model_inputs


def build_live_llama(**kwargs):
    """Build a LiveLlama model via ``build_live`` (defined in ``modeling_live``, not visible here).

    Returns whatever ``build_live`` returns; the demo below unpacks it as ``(model, tokenizer)``.
    """
    return build_live(config_class=LiveLlamaConfig, model_class=LiveLlamaForCausalLM, **kwargs)


if __name__ == '__main__':
    from ..arguments_live import LiveOnePlusTrainingArguments
    print(LiveOnePlusTrainingArguments().to_dict())
    model, tokenizer = build_live_llama(is_training=True, **LiveOnePlusTrainingArguments().to_dict())
    print(model.config, tokenizer)