Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import login | |
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline | |
import torch | |
import copy | |
import os | |
import spaces | |
import random | |
# Authenticate with the Hugging Face Hub only when a token is configured.
# `login(token=None)` falls back to an interactive prompt (widget or terminal),
# which nobody can answer on a headless Space; without a token, public repos
# remain downloadable anonymously.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Shared SDXL base pipeline, loaded once at import time in half precision.
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
def _load_and_fuse_lora(pipe, repo_id, weight_name, scale):
    """Load one LoRA checkpoint into *pipe* and fuse it at *scale*."""
    pipe.load_lora_weights(
        repo_id,
        weight_name=weight_name,
        low_cpu_mem_usage=True,
        use_auth_token=True,
    )
    # `lora_scale` must be passed by keyword: the first positional parameter of
    # `fuse_lora` is `fuse_unet` (older diffusers) or `components` (newer), so a
    # positional float never reached `lora_scale` and the UI scale was ignored.
    pipe.fuse_lora(lora_scale=scale)


def infer(lora_1_id, lora_1_sfts, lora_2_id, lora_2_sfts, prompt, negative_prompt, lora_1_scale, lora_2_scale, seed):
    """Fuse two LoRAs into a fresh copy of the SDXL base pipeline and render one image.

    Args:
        lora_1_id, lora_2_id: Hub repo ids passed to `load_lora_weights`.
        lora_1_sfts, lora_2_sfts: file names inside those repos (`weight_name`).
        prompt: positive text prompt.
        negative_prompt: negative text prompt; empty or blank means "none".
        lora_1_scale, lora_2_scale: fusion scale for each LoRA.
        seed: generator seed; any negative value selects a random seed.

    Returns:
        A tuple of (first image of the pipeline output, the integer seed used).
    """
    # NOTE(review): `spaces` is imported at file level but unused; if this Space
    # runs on ZeroGPU hardware, this function likely needs `@spaces.GPU` — confirm.

    # Deep-copy the modules whose weights get fused so the shared base pipeline
    # is never mutated; vae, scheduler and tokenizers are reused by reference.
    unet = copy.deepcopy(original_pipe.unet)
    text_encoder = copy.deepcopy(original_pipe.text_encoder)
    text_encoder_2 = copy.deepcopy(original_pipe.text_encoder_2)
    pipe = StableDiffusionXLPipeline(
        vae=original_pipe.vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        scheduler=original_pipe.scheduler,
        tokenizer=original_pipe.tokenizer,
        tokenizer_2=original_pipe.tokenizer_2,
        unet=unet,
    )
    pipe.to("cuda")

    _load_and_fuse_lora(pipe, lora_1_id, lora_1_sfts, lora_1_scale)
    _load_and_fuse_lora(pipe, lora_2_id, lora_2_sfts, lora_2_scale)

    if not negative_prompt or not negative_prompt.strip():
        negative_prompt = None

    # Gradio sliders may deliver floats; torch's manual_seed requires an int.
    seed = int(seed)
    if seed < 0:
        seed = random.randint(0, 423538377342)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=25,
        width=1024,
        height=1024,
        generator=generator,
    ).images[0]
    return image, seed
# Gradio UI: two LoRA selectors, a prompt, advanced settings, and the output.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        title = gr.HTML(
            '''
<h1 style="text-align: center;">LoRA Fusion</h1>
<p style="text-align: center;">Fuse 2 custom LoRa models</p>
'''
        )

        # PART 1 • MODELS
        with gr.Row():
            with gr.Column():
                lora_1_id = gr.Textbox(
                    label="LoRa 1 ID",
                    placeholder="username/model_id",
                )
                lora_1_sfts = gr.Textbox(
                    label="Safetensors file",
                    placeholder="specific_chosen.safetensors",
                )
            with gr.Column():
                lora_2_id = gr.Textbox(
                    label="LoRa 2 ID",
                    placeholder="username/model_id",
                )
                lora_2_sfts = gr.Textbox(
                    label="Safetensors file",
                    placeholder="specific_chosen.safetensors",
                )

        # PART 2 • INFERENCE
        with gr.Row():
            prompt = gr.Textbox(
                label="Your prompt",
                info="Use your trigger words into a coherent prompt",
                placeholder="e.g: a triggerWordOne portrait in triggerWord2 style",
            )
            run_btn = gr.Button("Run")

        output_image = gr.Image(
            label="Output"
        )

        # Advanced Settings
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                # `gr.Slider` takes `step`, not `steps`.
                lora_1_scale = gr.Slider(
                    label="LoRa 1 scale",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.7,
                )
                lora_2_scale = gr.Slider(
                    label="LoRa 2 scale",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.7,
                )
            negative_prompt = gr.Textbox(
                label="Negative prompt"
            )
            # Explicit step=1: without it Gradio derives the step from the range
            # (about 1e9 here), leaving almost every seed value unreachable.
            seed = gr.Slider(
                label="Seed",
                info="-1 denotes a random seed",
                minimum=-1,
                maximum=423538377342,
                step=1,
                value=-1,
            )
            last_used_seed = gr.Number(
                label="Last used seed",
                info="the seed used in the last generation",
            )

    # ACTIONS
    run_btn.click(
        fn=infer,
        inputs=[
            lora_1_id,
            lora_1_sfts,
            lora_2_id,
            lora_2_sfts,
            prompt,
            negative_prompt,
            lora_1_scale,
            lora_2_scale,
            seed,
        ],
        outputs=[
            output_image,
            last_used_seed,
        ],
    )

# NOTE(review): `concurrency_count` exists in Gradio 3.x `queue()` but was
# removed in Gradio 4 — confirm the pinned Gradio version.
demo.queue(concurrency_count=2).launch()