"""
Classes for loading pretrained models
"""
from typing import Any
from os.path import join

from omegaconf import OmegaConf

import torch
import torch.nn as nn

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer

from peft import prepare_model_for_kbit_training


def get_pretrained_loader(pretrained_model_name_or_path: str,
                          huggingface_token: str = None,
                          **model_kwargs: Any):
    """
    Return the appropriate loader for the pretrained model
    """
    # Substring matches ('lama', 'istral') catch both capitalized and lowercase
    # names, e.g. 'Llama' / 'llama' and 'Mistral' / 'mistral'.
    if 'lama' in pretrained_model_name_or_path:
        return PretrainedLlamaLoader(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            huggingface_token=huggingface_token,
            **model_kwargs,
        )
    elif 'istral' in pretrained_model_name_or_path:
        return PretrainedMistralLoader(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            huggingface_token=huggingface_token,
            **model_kwargs,
        )
    else:
        print(f'-> {pretrained_model_name_or_path} using default pretrained model loader')
        return PretrainedModelLoader(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            huggingface_token=huggingface_token,
            **model_kwargs,
        )
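# Illustrative usage (a sketch; the checkpoint name below is an assumption for
# illustration, not a requirement of this loader):
#
#   model_loader = get_pretrained_loader(
#       pretrained_model_name_or_path='mistralai/Mistral-7B-v0.1',
#       huggingface_token=None,
#       torch_dtype='bfloat16',
#       attn_implementation='sdpa',
#   )
#   model = model_loader.load()
#   tokenizer = model_loader.load_tokenizer()
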
class PretrainedModelLoader():
    """
    Class for loading a pretrained model.
    Example:
        model_loader = PretrainedModelLoader(**model_kwargs)
        model = model_loader.load()
    """
    def __init__(self,
                 pretrained_model_name_or_path: str,
                 cache_dir: str = None,
                 return_dict: bool = True,
                 device_map: str = 'auto',
                 low_cpu_mem_usage: bool = True,
                 torch_dtype: str = 'bfloat16',
                 rope_theta: float = 10000.,
                 attn_implementation: str = 'sdpa',
                 load_in_8bit: bool = False,
                 load_in_4bit: bool = False,
                 huggingface_token: str = None,
                 peft_id: str = None,
                 rope_scaling: dict = None,
                 **other_kwargs: Any) -> None:

        print(f'-> Using {attn_implementation} attention')

        # Keyword arguments forwarded to Hugging Face `from_pretrained` calls
        self.loading_kwargs = {
            'pretrained_model_name_or_path': pretrained_model_name_or_path,
            'cache_dir': cache_dir,
            'return_dict': return_dict,
            'load_in_8bit': load_in_8bit,
            'load_in_4bit': load_in_4bit,
            'device_map': device_map,
            'low_cpu_mem_usage': low_cpu_mem_usage,
            'torch_dtype': getattr(torch, torch_dtype),  # e.g., 'bfloat16' -> torch.bfloat16
            'rope_theta': rope_theta,
            'attn_implementation': attn_implementation,
        }
        if rope_scaling is not None:
            # OmegaConf configs are not plain dicts; convert before passing to Hugging Face
            rope_scaling = OmegaConf.to_container(rope_scaling)
            self.loading_kwargs['rope_scaling'] = rope_scaling
        for k, v in other_kwargs.items():
            self.loading_kwargs[k] = v

        self.quantization = load_in_8bit or load_in_4bit
        self.peft_id = peft_id
        self.gradient_checkpointing = False
        if huggingface_token is not None:
            self.loading_kwargs['token'] = huggingface_token

        if self.quantization:
            raise NotImplementedError('Quantized (4-bit / 8-bit) loading is not supported by this loader')
    def load(self) -> nn.Module:
        """
        Load pretrained model
        """
        model = AutoModelForCausalLM.from_pretrained(**self.loading_kwargs)
        if self.quantization:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=self.gradient_checkpointing,
                gradient_checkpointing_kwargs={'use_reentrant': False},
            )
        return model

    def load_tokenizer(self):
        """
        Load pretrained tokenizer
        """
        try:
            return AutoTokenizer.from_pretrained(**self.loading_kwargs)
        except Exception as e:
            print("-> Error with `AutoTokenizer.from_pretrained(**self.loading_kwargs)`:", e)
            print("-> Trying `LlamaTokenizer.from_pretrained(**self.loading_kwargs)`")

        return LlamaTokenizer.from_pretrained(**self.loading_kwargs)
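# Config-driven construction (a sketch; the YAML path and its keys are
# assumptions for illustration, not a schema defined by this module):
#
#   model_config = OmegaConf.load('configs/model/base_model.yaml')
#   loader = PretrainedModelLoader(**model_config['pretrained_config'])
#   model = loader.load()
#   tokenizer = loader.load_tokenizer()
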
class PretrainedLlamaLoader(PretrainedModelLoader):
    def load(self, model_type: str = None):
        # True for transformers releases newer than 4.42, i.e., those with Llama 3.1
        # support; not used further in this method.
        llama3_1 = float('.'.join(transformers.__version__.split('.')[:2])) > 4.42
        if model_type is None:
            from transformers import LlamaForCausalLM as model_class

        # Check the more specific model_type strings first, since 'lolcats_llama'
        # is a substring of 'lolcats_llama_sharded'.
        elif 'lolcats_llama_sharded' in model_type:
            from .modeling_llama_sharded import ShardedLolcatsLlamaForCausalLM as model_class

        elif 'lolcats_long_llama' in model_type:
            from .modeling_llama import LooooolcatsLlamaForCausalLM as model_class

        elif 'lolcats_llama' in model_type:
            from .modeling_llama import LolcatsLlamaForCausalLM as model_class

        else:
            if model_type == 'flash_attention_2':
                self.loading_kwargs['attn_implementation'] = model_type
            from transformers import AutoModelForCausalLM as model_class
            print('-> Loading from AutoModelForCausalLM')

        model = model_class.from_pretrained(**self.loading_kwargs)
        if self.peft_id is not None:
            from peft import PeftModel
            print('-> Loading PEFT checkpoint')
            # Merge the adapter weights into the base model so downstream code
            # sees a single plain model.
            model = PeftModel.from_pretrained(
                model,
                self.peft_id,
                torch_dtype=self.loading_kwargs['torch_dtype'],
                device_map='auto',
                cache_dir=self.loading_kwargs['cache_dir'],
            ).merge_and_unload()

        if self.quantization:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=self.gradient_checkpointing,
                gradient_checkpointing_kwargs={'use_reentrant': False},
            )
        return model

    def load_tokenizer(self):
        return AutoTokenizer.from_pretrained(**self.loading_kwargs)
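# Illustrative usage (a sketch; the checkpoint name is an assumption, while the
# model_type strings are those handled by `PretrainedLlamaLoader.load` above):
#
#   loader = PretrainedLlamaLoader(
#       pretrained_model_name_or_path='meta-llama/Meta-Llama-3.1-8B',
#       torch_dtype='bfloat16',
#   )
#   model = loader.load(model_type='lolcats_llama')
#   tokenizer = loader.load_tokenizer()
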
class PretrainedMistralLoader(PretrainedModelLoader):
    def load(self, model_type: str = None):
        if model_type is None:
            from transformers import MistralForCausalLM as model_class
        elif 'lolcats_long_llama' in model_type:
            from .modeling_mistral import LooooolcatsMistralForCausalLM as model_class
        elif 'lolcats_llama' in model_type:
            from .modeling_mistral import LolcatsMistralForCausalLM as model_class
        else:
            if model_type == 'flash_attention_2':
                self.loading_kwargs['attn_implementation'] = model_type
            from transformers import AutoModelForCausalLM as model_class
            print('-> Loading from AutoModelForCausalLM')

        model = model_class.from_pretrained(**self.loading_kwargs)
        if self.peft_id is not None:
            from peft import PeftModel
            # As in the Llama loader, merge adapter weights into the base model
            model = PeftModel.from_pretrained(
                model,
                self.peft_id,
                torch_dtype=self.loading_kwargs['torch_dtype'],
                device_map='auto',
                cache_dir=self.loading_kwargs['cache_dir'],
            ).merge_and_unload()

        if self.quantization:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=self.gradient_checkpointing,
                gradient_checkpointing_kwargs={'use_reentrant': False},
            )
        return model