# hybridModel.py: merge the FinGPT sentiment LoRA adapter into its
# Llama-2-13B base model and save the result as a standalone checkpoint.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model and tokenizer (fp16 halves the ~52 GB fp32 footprint)
base_model_name = "NousResearch/Llama-2-13b-hf"
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Attach the LoRA adapter to the base model. The adapter repo contains only
# the low-rank delta weights, so it cannot be loaded directly with
# AutoModelForCausalLM; it has to be wrapped around the base model via peft.
lora_model_name = "FinGPT/fingpt-sentiment_llama2-13b_lora"
lora_model = PeftModel.from_pretrained(base_model, lora_model_name)
# Merge the LoRA deltas into the base weights (W <- W + scaling * B @ A) and
# strip the adapter wrappers, leaving a plain transformers model. A raw
# state-dict copy cannot do this: the adapter's lora_A/lora_B parameter names
# never match the base model's, and LoRA adds a low-rank update rather than
# replacing whole weight matrices.
merged_model = lora_model.merge_and_unload()
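
# Sanity check (optional): merge_and_unload() should leave no LoRA parameters
# behind, so the saved checkpoint loads as a plain LlamaForCausalLM.
assert not any("lora_" in name for name, _ in merged_model.named_parameters())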
# Save the merged model
output_dir = "./hybrid_model"
merged_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
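
# Quick smoke test of the merged model. This is a minimal sketch: the prompt
# below assumes the instruction template commonly used for FinGPT sentiment
# adapters; adjust it to match the adapter's actual training format.
prompt = (
    "Instruction: What is the sentiment of this news? "
    "Please choose an answer from {negative/neutral/positive}.\n"
    "Input: The company's quarterly earnings beat analyst expectations.\n"
    "Answer: "
)
inputs = tokenizer(prompt, return_tensors="pt").to(merged_model.device)
with torch.no_grad():
    output_ids = merged_model.generate(**inputs, max_new_tokens=8)
# Decode only the newly generated tokens, skipping the echoed prompt
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))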