File size: 6,448 Bytes
564df58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import gradio as gr
from os import path
from backend.lora import (
    get_lora_models,
    get_active_lora_weights,
    update_lora_weights,
    load_lora_weight,
)
from state import get_settings, get_context
from frontend.utils import get_valid_lora_model
from models.interface_types import InterfaceType
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting


_MAX_LORA_WEIGHTS = 5

_custom_lora_sliders = []
_custom_lora_names = []
_custom_lora_columns = []

app_settings = get_settings()


def on_click_update_weight(*lora_weights):
    update_weights = []
    active_weights = get_active_lora_weights()
    if not len(active_weights):
        gr.Warning("No active LoRAs, first you need to load LoRA model")
        return
    for idx, lora in enumerate(active_weights):
        update_weights.append(
            (
                lora[0],
                lora_weights[idx],
            )
        )
    if len(update_weights) > 0:
        update_lora_weights(
            get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline,
            app_settings.settings.lcm_diffusion_setting,
            update_weights,
        )


def on_click_load_lora(lora_name, lora_weight):
    if app_settings.settings.lcm_diffusion_setting.use_openvino:
        gr.Warning("Currently LoRA is not supported in OpenVINO.")
        return
    lora_models_map = get_lora_models(
        app_settings.settings.lcm_diffusion_setting.lora.models_dir
    )

    # Load a new LoRA
    settings = app_settings.settings.lcm_diffusion_setting
    settings.lora.fuse = False
    settings.lora.enabled = False
    settings.lora.path = lora_models_map[lora_name]
    settings.lora.weight = lora_weight
    if not path.exists(settings.lora.path):
        gr.Warning("Invalid LoRA model path!")
        return
    pipeline = get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline
    if not pipeline:
        gr.Warning("Pipeline not initialized. Please generate an image first.")
        return
    settings.lora.enabled = True
    load_lora_weight(
        get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline,
        settings,
    )

    # Update Gradio LoRA UI
    global _MAX_LORA_WEIGHTS
    values = []
    labels = []
    rows = []
    active_weights = get_active_lora_weights()
    for idx, lora in enumerate(active_weights):
        labels.append(f"{lora[0]}: ")
        values.append(lora[1])
        rows.append(gr.Row.update(visible=True))
    for i in range(len(active_weights), _MAX_LORA_WEIGHTS):
        labels.append(f"Update weight")
        values.append(0.0)
        rows.append(gr.Row.update(visible=False))
    return labels + values + rows


def get_lora_models_ui() -> None:
    with gr.Blocks() as ui:
        gr.HTML(
            "Download and place your LoRA model weights in <b>lora_models</b> folders and restart App"
        )
        with gr.Row():

            with gr.Column():
                with gr.Row():
                    lora_models_map = get_lora_models(
                        app_settings.settings.lcm_diffusion_setting.lora.models_dir
                    )
                    valid_model = get_valid_lora_model(
                        list(lora_models_map.values()),
                        app_settings.settings.lcm_diffusion_setting.lora.path,
                        app_settings.settings.lcm_diffusion_setting.lora.models_dir,
                    )
                    if valid_model != "":
                        valid_model_path = lora_models_map[valid_model]
                        app_settings.settings.lcm_diffusion_setting.lora.path = (
                            valid_model_path
                        )
                    else:
                        app_settings.settings.lcm_diffusion_setting.lora.path = ""

                    lora_model = gr.Dropdown(
                        lora_models_map.keys(),
                        label="LoRA model",
                        info="LoRA model weight to load (You can use Lora models from Civitai or Hugging Face .safetensors format)",
                        value=valid_model,
                        interactive=True,
                    )

                    lora_weight = gr.Slider(
                        0.0,
                        1.0,
                        value=app_settings.settings.lcm_diffusion_setting.lora.weight,
                        step=0.05,
                        label="Initial Lora weight",
                        interactive=True,
                    )
                    load_lora_btn = gr.Button(
                        "Load selected LoRA",
                        elem_id="load_lora_button",
                        scale=0,
                    )

                with gr.Row():
                    gr.Markdown(
                        "## Loaded LoRA models",
                        show_label=False,
                    )
                    update_lora_weights_btn = gr.Button(
                        "Update LoRA weights",
                        elem_id="load_lora_button",
                        scale=0,
                    )

                global _MAX_LORA_WEIGHTS
                global _custom_lora_sliders
                global _custom_lora_names
                global _custom_lora_columns
                for i in range(0, _MAX_LORA_WEIGHTS):
                    new_row = gr.Column(visible=False)
                    _custom_lora_columns.append(new_row)
                    with new_row:
                        lora_name = gr.Markdown(
                            "Lora Name",
                            show_label=True,
                        )
                        lora_slider = gr.Slider(
                            0.0,
                            1.0,
                            step=0.05,
                            label="LoRA weight",
                            interactive=True,
                            visible=True,
                        )

                        _custom_lora_names.append(lora_name)
                        _custom_lora_sliders.append(lora_slider)

    load_lora_btn.click(
        fn=on_click_load_lora,
        inputs=[lora_model, lora_weight],
        outputs=[
            *_custom_lora_names,
            *_custom_lora_sliders,
            *_custom_lora_columns,
        ],
    )

    update_lora_weights_btn.click(
        fn=on_click_update_weight,
        inputs=[*_custom_lora_sliders],
        outputs=None,
    )