Spaces:
Runtime error
Runtime error
File size: 7,089 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# https://github.com/showlab/X-Adapter
import torch
import diffusers
import gradio as gr
import huggingface_hub as hf
from modules import errors, shared, devices, scripts, processing, sd_models, sd_samplers
adapter = None # module-level cache for the Adapter_XL model: populated on the first Script.run() call and reused by later calls
class Script(scripts.Script):
    """Script that runs generation through the X-Adapter SDXL pipeline.

    The currently loaded SDXL model is used as the base and the model selected
    in the UI is loaded as the refiner; both are wired, together with the
    X-Adapter weights, into a StableDiffusionXLAdapterPipeline.
    """

    def title(self):
        """Return the display name of the script."""
        return 'X-Adapter'

    def show(self, is_img2img):
        """Return whether the script is offered in the UI (currently always False)."""
        return False
        # return True if shared.backend == shared.Backend.DIFFUSERS else False

    def ui(self, _is_img2img):
        """Build the script controls.

        Returns:
            Tuple of components (model, sampler, width, height, start, scale, lora)
            whose values are passed to run() in the same order.
        """
        with gr.Row():
            gr.HTML('<a href="https://github.com/showlab/X-Adapter">  X-Adapter</a><br>')
        with gr.Row():
            model = gr.Dropdown(label='Adapter model', choices=['None'] + sd_models.checkpoint_tiles(), value='None')
            sampler = gr.Dropdown(label='Adapter sampler', choices=[s.name for s in sd_samplers.samplers], value='Default')
        with gr.Row():
            width = gr.Slider(label='Adapter width', minimum=64, maximum=2048, step=8, value=512)
            height = gr.Slider(label='Adapter height', minimum=64, maximum=2048, step=8, value=512)
        with gr.Row():
            start = gr.Slider(label='Adapter start', minimum=0.0, maximum=1.0, step=0.01, value=0.5)
            scale = gr.Slider(label='Adapter scale', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
        with gr.Row():
            # 'value' is the supported keyword for the initial content; 'default' is the deprecated Gradio 2.x name
            lora = gr.Textbox(value='', label='Adapter LoRA')
        return model, sampler, width, height, start, scale, lora

    def run(self, p: processing.StableDiffusionProcessing, model, sampler, width, height, start, scale, lora): # pylint: disable=arguments-differ
        """Run processing with the X-Adapter pipeline temporarily installed as shared.sd_model.

        Args:
            p: processing job; its task_args are extended with the adapter arguments.
            model: checkpoint title to load as the refiner, or 'None' to skip the script.
            sampler: sampler name for the adapter scheduler, or 'Default'.
            width: value passed to the pipeline as 'width_sd1_5'.
            height: value passed to the pipeline as 'height_sd1_5'.
            start: value passed to the pipeline as 'adapter_guidance_start'.
            scale: value passed to the pipeline as 'adapter_condition_scale'.
            lora: currently unused.

        Returns:
            The result of processing.process_images(p), or None when the script
            is skipped or a precondition (model types, adapter weights) fails.
        """
        from modules.xadapter.xadapter_hijacks import PositionNet
        diffusers.models.embeddings.PositionNet = PositionNet # patch diffusers==0.26 from diffusers==0.20
        from modules.xadapter.adapter import Adapter_XL
        from modules.xadapter.pipeline_sd_xl_adapter import StableDiffusionXLAdapterPipeline
        from modules.xadapter.unet_adapter import UNet2DConditionModel as UNet2DConditionModelAdapter
        global adapter # pylint: disable=global-statement
        if model == 'None':
            return
        shared.opts.sd_model_refiner = model
        if shared.sd_model_type != 'sdxl':
            shared.log.error(f'X-Adapter: incorrect base model: {shared.sd_model.__class__.__name__}')
            return

        if adapter is None:
            shared.log.debug('X-Adapter: adapter loading')
            try:
                # build into a local first so a failed download/load never leaves a
                # weightless adapter cached in the module-level global
                loaded = Adapter_XL()
                adapter_path = hf.hf_hub_download(repo_id='Lingmin-Ran/X-Adapter', filename='X_Adapter_v1.bin')
                # NOTE(security): torch.load unpickles the file; only trusted checkpoints should be loaded here
                # map_location keeps loading independent of the device the checkpoint was saved from
                adapter_dict = torch.load(adapter_path, map_location='cpu')
                loaded.load_state_dict(adapter_dict)
                adapter = loaded
            except Exception as e:
                errors.display(e, 'X-Adapter: adapter loading failed')
        if adapter is None:
            shared.log.error('X-Adapter: adapter loading failed')
            return
        try:
            # best-effort: the adapter is offloaded to cpu at the end of every run
            sd_models.move_model(adapter, devices.device)
        except Exception as e:
            shared.log.debug(f'X-Adapter: adapter move failed: {e}')

        # reload both models while diffusers is patched with the x-adapter unet class
        sd_models.unload_model_weights(op='model')
        sd_models.unload_model_weights(op='refiner')
        orig_unetcondmodel = diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
        diffusers.models.UNet2DConditionModel = UNet2DConditionModelAdapter # patch diffusers with x-adapter
        diffusers.models.unets.unet_2d_condition.UNet2DConditionModel = UNet2DConditionModelAdapter # patch diffusers with x-adapter
        try:
            sd_models.reload_model_weights(op='model')
            sd_models.reload_model_weights(op='refiner')
        finally:
            # always unpatch, even if a reload raises
            diffusers.models.unets.unet_2d_condition.UNet2DConditionModel = orig_unetcondmodel # unpatch diffusers
            diffusers.models.UNet2DConditionModel = orig_unetcondmodel # unpatch diffusers

        if shared.sd_refiner_type != 'sd':
            # the adapter model is the one loaded as refiner, so report that class
            shared.log.error(f'X-Adapter: incorrect adapter model: {shared.sd_refiner.__class__.__name__}')
            return

        # backup pipeline and params
        orig_pipeline = shared.sd_model
        orig_prompt_attention = shared.opts.prompt_attention
        pipe = None

        try:
            try:
                shared.log.debug('X-Adapter: creating pipeline')
                pipe = StableDiffusionXLAdapterPipeline(
                    vae=shared.sd_model.vae,
                    text_encoder=shared.sd_model.text_encoder,
                    text_encoder_2=shared.sd_model.text_encoder_2,
                    tokenizer=shared.sd_model.tokenizer,
                    tokenizer_2=shared.sd_model.tokenizer_2,
                    unet=shared.sd_model.unet,
                    scheduler=shared.sd_model.scheduler,
                    vae_sd1_5=shared.sd_refiner.vae,
                    text_encoder_sd1_5=shared.sd_refiner.text_encoder,
                    tokenizer_sd1_5=shared.sd_refiner.tokenizer,
                    unet_sd1_5=shared.sd_refiner.unet,
                    scheduler_sd1_5=shared.sd_refiner.scheduler,
                    adapter=adapter,
                )
                sd_models.copy_diffuser_options(pipe, shared.sd_model)
                sd_models.set_diffuser_options(pipe)
                try:
                    pipe.to(device=devices.device, dtype=devices.dtype)
                except Exception as e:
                    shared.log.debug(f'X-Adapter: pipeline move failed: {e}')
                shared.opts.data['prompt_attention'] = 'Fixed attention'
                prompt = shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)
                negative = shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
                p.task_args['prompt'] = prompt
                p.task_args['negative_prompt'] = negative
                p.task_args['prompt_sd1_5'] = prompt
                p.task_args['width_sd1_5'] = width
                p.task_args['height_sd1_5'] = height
                p.task_args['adapter_guidance_start'] = start
                p.task_args['adapter_condition_scale'] = scale
                p.task_args['fusion_guidance_scale'] = 1.0 # ???
                if sampler != 'Default':
                    pipe.scheduler_sd1_5 = sd_samplers.create_sampler(sampler, shared.sd_refiner)
                else:
                    pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
                    pipe.scheduler_sd1_5 = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler_sd1_5.config)
                    pipe.scheduler_sd1_5.config.timestep_spacing = "leading"
                shared.log.debug(f'X-Adapter: pipeline={pipe.__class__.__name__} args={p.task_args}')
                shared.sd_model = pipe
            except Exception as e:
                shared.log.error(f'X-Adapter: pipeline creation failed: {e}')
                errors.display(e, 'X-Adapter: pipeline creation failed')
                shared.sd_model = orig_pipeline

            # run pipeline
            processed: processing.Processed = processing.process_images(p) # runs processing using main loop
        finally:
            # restore pipeline and params even when processing raises, so the
            # adapter pipeline and attention parser never leak into later jobs
            try:
                if adapter is not None:
                    adapter.to(devices.cpu)
            except Exception as e:
                shared.log.debug(f'X-Adapter: adapter offload failed: {e}')
            pipe = None
            shared.opts.data['prompt_attention'] = orig_prompt_attention
            shared.sd_model = orig_pipeline
            devices.torch_gc()
        return processed
|