File size: 1,961 Bytes
097d1a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer, AutoModel
from diffusers import StableDiffusionPipeline
import torch

# Input LoRA files. Only file1_path and file2_path are read by the merge
# below; file3_path is defined but never referenced by it.
file1_path: str = "lora.TA_trained (2).safetensors"
file2_path: str = "my_first_flux_lora_v1.safetensors"
file3_path: str = "NSFW_master.safetensors"

# Where the merged result is written.
merged_file_path: str = "merged_lora.safetensors"

# Per-file scale factors: every tensor loaded from file1 is multiplied by
# weight1 and every tensor from file2 by weight2 before the two are summed.
weight1: float = 0.8
weight2: float = 0.2

def load_and_weight_tensors(file_path, weight):
    """Read a safetensors file and scale each of its tensors by ``weight``.

    Args:
        file_path: Path of the ``.safetensors`` file to load.
        weight: Scalar multiplier applied to every tensor in the file.

    Returns:
        A new dict with the same keys as the file, each mapped to
        ``weight * tensor``.
    """
    scaled = {}
    for name, value in load_file(file_path).items():
        scaled[name] = weight * value
    return scaled

def _merge_tensor_dicts(base, extra):
    """Return a new dict that sums the tensors ``base`` and ``extra`` share by key.

    Keys found in only one input are carried over unchanged. Keys from
    ``base`` come first, followed by keys that appear only in ``extra``, in
    their original order. Neither input dict nor any input tensor is modified.

    Args:
        base: Mapping of key -> tensor (here: the weighted first LoRA).
        extra: Mapping of key -> tensor to add onto ``base``.

    Returns:
        A new ``dict`` of merged tensors. Summed entries keep ``base``'s dtype.

    Raises:
        ValueError: If a shared key holds tensors of different shapes. Without
            this check ``+`` would either broadcast silently (e.g. a (1, n)
            tensor onto a (16, n) one) or fail with an opaque RuntimeError.
    """
    merged = dict(base)
    for key, tensor in extra.items():
        if key not in merged:
            merged[key] = tensor
            continue
        current = merged[key]
        if current.shape != tensor.shape:
            raise ValueError(
                f"Cannot merge key {key!r}: shape {tuple(current.shape)} "
                f"does not match {tuple(tensor.shape)}"
            )
        # Out-of-place add so the caller's tensors are untouched; cast back so
        # the merged file keeps the first LoRA's dtype.
        merged[key] = (current + tensor).to(current.dtype)
    return merged


try:
    # Load each LoRA and scale all of its tensors by its weight.
    tensors1 = load_and_weight_tensors(file1_path, weight1)
    tensors2 = load_and_weight_tensors(file2_path, weight2)

    # NOTE(review): every tensor is scaled, so if these files store factored
    # LoRA pairs (e.g. up/down matrices), each low-rank product ends up scaled
    # by weight**2, and summing factors of two different LoRAs is not the same
    # as summing their weight deltas. Confirm the intended merge semantics.
    merged_tensors = _merge_tensor_dicts(tensors1, tensors2)

    # Save the weighted merged tensors
    save_file(merged_tensors, merged_file_path)
    print(f"Merged file with weights saved at: {merged_file_path}")

    # Re-read the file just written as a basic round-trip check.
    merged_load = load_file(merged_file_path)
    print("Keys in merged file:", merged_load.keys())

    try:
        # NOTE(review): `tokenizer` and `model` are not used afterwards, and
        # the merged LoRA is never applied to `pipeline`. If the repo uses the
        # diffusers layout, the tokenizer/text encoder live in subfolders.
        # Verify against the repo.
        tokenizer = AutoTokenizer.from_pretrained("Tjay143/Hijab2")
        # transformers' documented flag for preferring .safetensors weights is
        # `use_safetensors`; `from_safetensors` is not a from_pretrained argument.
        model = AutoModel.from_pretrained(
            "Tjay143/Hijab2", from_tf=False, use_safetensors=True
        )
        pipeline = StableDiffusionPipeline.from_pretrained(
            "Tjay143/Hijab2", torch_dtype=torch.float16
        )
        print("Pipeline loaded successfully!")
    except Exception as e:
        print(f"Error loading model/pipeline: {e}")
except Exception as e:
    print(f"An error occurred: {e}")