""" |
|
Thin wrappers and replacement classes for MistralForCausalLM |
|
""" |
|
from typing import Any, Optional, Tuple, List, Union

import warnings
import torch
import torch.nn as nn

from transformers import MistralModel, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .modeling_llama import LolcatsLlamaModel
from .convert_model import get_attention_cache
|
|
class LolcatsMistralModel(LolcatsLlamaModel, MistralModel):
    """
    Wrapper for Mistral-like autoregressive language model
    """
|
    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)
|
|
class LolcatsMistralForCausalLM(MistralForCausalLM):
    """
    Wrapper for Mistral-like autoregressive language model
    """
|
    def __init__(self, config):
        # Backfill Llama-style config attributes that a Mistral config may not define,
        # so shared code paths that read them do not fail.
        if getattr(config, 'attention_bias', None) is None:
            config.attention_bias = False
        if getattr(config, 'rope_scaling', None) is None:
            config.rope_scaling = None
        if getattr(config, 'pretraining_tp', None) is None:
            config.pretraining_tp = 1
        if getattr(config, 'mlp_bias', None) is None:
            config.mlp_bias = False
        super().__init__(config)
        self.model = LolcatsMistralModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
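    # Minimal usage sketch (illustrative only; the checkpoint name is an assumption, and
    # any LoLCATs linear-attention conversion is applied elsewhere, e.g. via convert_model):
    #
    #   from transformers import AutoConfig
    #   config = AutoConfig.from_pretrained('mistralai/Mistral-7B-v0.1')
    #   model = LolcatsMistralForCausalLM(config)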
|
|
class LooooolcatsMistralForCausalLM(LolcatsMistralForCausalLM):
    """
    Wrapper for Mistral-like autoregressive language model
    -> Experimental / WIP; the goal is to combine chunked linear attention during training
       to process long contexts with minimally growing memory usage
    """
|
    def chunk_forward(self, *args: Any, **kwargs: Any):
        """Call this when training / processing one chunk"""
        return super().forward(*args, **kwargs)
|
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass that splits the input into chunks of `self.state_chunk_len` tokens
        and processes them sequentially, carrying attention state across chunks
        """
|
        self.generating = False
        use_cache = True  # chunked processing always carries state / cache forward

        if attention_mask is not None and use_cache:
            warnings.warn(
                "Padding is currently not supported; setting attention_mask to None "
                "(outputs will still be causal)."
            )
            attention_mask = None

        if past_key_values is None:
            # Determine and set up the attention state / KV cache based on the layers' attention type
            attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None)
            past_key_values = get_attention_cache(attention_type)

        # Single-token input at inference time -> a standard (non-chunked) decoding step
        if input_ids.shape[-1] == 1 and not self.training:
            return super().forward(input_ids, attention_mask, position_ids,
                                   past_key_values, inputs_embeds, labels,
                                   use_cache, output_attentions, output_hidden_states,
                                   return_dict)
|
        else:
            if self.generating:  # Coming out of generation; reset the state before a new prompt
                self.generating = False
                # Re-initialize the attention state / KV cache
                attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None)
                past_key_values = get_attention_cache(attention_type)
                print('-> attention_type:', attention_type)

            # Only track gradients through the recurrent attention state when training
            for idx in range(len(self.model.layers)):
                self.model.layers[idx].self_attn.state_grad_enabled = self.training

            # Split inputs into chunks of `self.state_chunk_len` tokens along the sequence dim
            input_ids = torch.split(input_ids, self.state_chunk_len, dim=-1)
            if position_ids is not None:
                position_ids = torch.split(position_ids, self.state_chunk_len, dim=-1)

            # Process each chunk, carrying the attention state across chunks
            all_logits = []
            for _idx, _input_ids in enumerate(input_ids):
                outputs = super().forward(_input_ids, None,
                                          position_ids[_idx] if position_ids is not None else None,
                                          past_key_values, inputs_embeds,
                                          labels=None,
                                          use_cache=True,
                                          output_attentions=False,
                                          output_hidden_states=False,
                                          return_dict=True)
                past_key_values = outputs.past_key_values
                all_logits.append(outputs.logits)

                # After the last chunk the prompt has been consumed; subsequent
                # single-token calls are treated as generation steps
                if _idx == len(input_ids) - 1:
                    self.generating = True

            # Return the concatenated per-chunk logits and the final attention state
            return CausalLMOutputWithPast(
                logits=torch.cat(all_logits, dim=-2),
                past_key_values=past_key_values,
            )