from huggingface_hub import hf_hub_download

# download the InstantIR weights into ./models
hf_hub_download(repo_id="InstantX/InstantIR", filename="models/adapter.pt", local_dir=".")
hf_hub_download(repo_id="InstantX/InstantIR", filename="models/aggregator.pt", local_dir=".")
hf_hub_download(repo_id="InstantX/InstantIR", filename="models/previewer_lora_weights.bin", local_dir=".")

import numpy as np  # used by resize_img for padding
import torch
from PIL import Image
from diffusers import DDPMScheduler

from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
from module.ip_adapter.utils import load_adapter_to_pipe
from pipelines.sdxl_instantir import InstantIRPipeline

device = 'cuda'


def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
    """Resize to an explicit size, or to dimensions that are multiples of base_pixel_number."""
    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        # ratio = min_side / min(h, w)
        # w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        # center the resized image on a white max_side x max_side canvas
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image


# models were downloaded under ./models
instantir_path = './models'

# load the SDXL base pipeline
pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)

# load the InstantIR adapter with a DINOv2 image encoder
load_adapter_to_pipe(
    pipe,
    f"{instantir_path}/adapter.pt",
    image_encoder_or_path='facebook/dinov2-large',
)

# load the previewer LoRA and set up the schedulers
pipe.prepare_previewers(instantir_path)
pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)

# load the aggregator weights (to CPU first; moved to GPU below)
pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt", map_location="cpu")
pipe.aggregator.load_state_dict(pretrained_state_dict)

# send everything to the GPU in fp16
pipe.to(device=device, dtype=torch.float16)
pipe.aggregator.to(device=device, dtype=torch.float16)

PROMPT = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
ultra HD, extreme meticulous detailing, skin pore detailing, \
hyper sharpness, perfect without deformations, \
taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading."

NEG_PROMPT = "blurry, out of focus, unclear, depth of field, over-smooth, \
sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
watermark, signature, jpeg artifacts, deformed, lowres"


def infer(prompt, input_image, steps=30, cfg_scale=7.0, guidance_end=1.0,
          creative_restoration=False, seed=3407, height=1024, width=1024):
    # note: guidance_end and creative_restoration are unused in this minimal demo
    # load the degraded input image
    low_quality_image = Image.open(input_image).convert("RGB")
    lq = [resize_img(low_quality_image, size=(width, height))]
    generator = torch.Generator(device=device).manual_seed(seed)

    # evenly spaced timesteps, reversed to run from high noise to low
    timesteps = [
        i * (1000 // steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
    ]
    timesteps = timesteps[::-1]

    # fall back to the default prompts when none is given
    prompt = PROMPT if len(prompt) == 0 else prompt
    neg_prompt = NEG_PROMPT

    # InstantIR restoration
    image = pipe(
        prompt=[prompt] * len(lq),
        image=lq,
        num_inference_steps=steps,
        generator=generator,
        timesteps=timesteps,
        negative_prompt=[neg_prompt] * len(lq),
        guidance_scale=cfg_scale,
        previewer_scheduler=lcm_scheduler,
    ).images[0]
    return image


import gradio as gr

with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                lq_img = gr.Image(label="Low-quality image", type="filepath")
                with gr.Group():
                    prompt = gr.Textbox(label="Prompt", value="")
                    submit_btn = gr.Button("InstantIR magic!")
        output_img = gr.Image(label="InstantIR restored")
    submit_btn.click(
        fn=infer,
        inputs=[prompt, lq_img],
        outputs=[output_img]
    )

demo.launch(show_error=True)