# SD35-IP-Adapter / app.py
# Last updated by multimodalart (commit 20c1d49, "Update app.py").
import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
import os
from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
from transformers import AutoProcessor, SiglipVisionModel
from huggingface_hub import hf_hub_download
# Constants
MAX_SEED = np.iinfo(np.int32).max  # largest signed 32-bit int; upper bound for random seeds
MAX_IMAGE_SIZE = 1024  # upper bound (pixels) of the width/height sliders
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub identifiers for the base model and the image encoder.
model_path = 'stabilityai/stable-diffusion-3.5-large'
image_encoder_path = "google/siglip-so400m-patch14-384"
# Local filesystem path of the IP-Adapter checkpoint fetched from the Hub.
ipadapter_path = hf_hub_download(
    repo_id="InstantX/SD3.5-Large-IP-Adapter",
    filename="ip-adapter.bin",
)

# Load the transformer through the project-local SD3Transformer2DModel
# (models/transformer_sd3) and pass it to the pipeline explicitly.
# NOTE(review): presumably this variant carries the IP-Adapter attention
# changes the stock class lacks — confirm.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
# Use the detected DEVICE instead of a hard-coded "cuda" so the placement
# agrees with the detection above (the original ignored DEVICE entirely).
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to(DEVICE)
# Configure the pipeline's IP-Adapter (custom method from
# pipeline_stable_diffusion_3_ipa) with the downloaded weights and the
# SigLIP image encoder.
# NOTE(review): nb_token presumably is the number of image tokens the adapter
# produces and must match the checkpoint — confirm.
pipe.init_ipadapter(
    ip_adapter_path=ipadapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
def resize_img(image, max_size=1024):
    """Resize *image* so that its longer side equals ``max_size``.

    The aspect ratio is preserved. Images smaller than ``max_size`` are scaled
    up, and larger ones are scaled down.

    Args:
        image: a PIL image.
        max_size: target length in pixels of the image's longer side.

    Returns:
        A new PIL image resized with LANCZOS resampling.
    """
    width, height = image.size
    scaling_factor = min(max_size / width, max_size / height)
    # Clamp to at least 1 px: for extreme aspect ratios the shorter side
    # would otherwise truncate to 0, which Image.resize rejects.
    new_width = max(1, int(width * scaling_factor))
    new_height = max(1, int(height * scaling_factor))
    return image.resize((new_width, new_height), Image.LANCZOS)
@spaces.GPU
def process_image(
    image,
    prompt,
    scale,
    seed,
    randomize_seed,
    width,
    height,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image from a reference image and a text prompt.

    Args:
        image: reference image, either a PIL image or an array accepted by
            ``Image.fromarray``. ``None`` skips generation.
        prompt: text prompt passed to the pipeline.
        scale: IP-Adapter strength, forwarded as ``ipadapter_scale``. The UI
            slider spans 0.0-1.0.
        seed: seed for the torch generator. It is replaced when
            ``randomize_seed`` is true.
        randomize_seed: if true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: output width in pixels, forwarded to the pipeline.
        height: output height in pixels, forwarded to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``). It is not
            referenced in the body.

    Returns:
        A ``(result, seed)`` tuple. ``result`` is the first generated image,
        or ``None`` when no reference image was given. ``seed`` is the integer
        seed actually used, so the UI can show it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver floats; torch's manual_seed requires an int.
    seed = int(seed)
    if image is None:
        return None, seed

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalise the mode first so RGBA / greyscale / palette uploads reach the
    # image encoder as 3-channel RGB (and resize never hits palette mode).
    image = resize_img(image.convert("RGB"))

    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        clip_image=image,
        prompt=prompt,
        ipadapter_scale=scale,
        # Latent shapes need integer sizes.
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]
    return result, seed
# Minimal stylesheet: centre the main column and cap its width.
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
"""

# Settings shared by the output width/height sliders (pixels, step 32).
_size_slider_kwargs = {
    "minimum": 256,
    "maximum": MAX_IMAGE_SIZE,
    "step": 32,
    "value": 1024,
}

# Layout: reference image, scale and prompt on the left; the result on the
# right; seed and output-size controls in a collapsed accordion underneath.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# InstantX's SD3.5 IP Adapter")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                scale = gr.Slider(
                    value=0.7,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    label="Image Scale",
                )
                prompt = gr.Text(
                    placeholder="Enter your prompt",
                    max_lines=1,
                    label="Prompt",
                )
                run_button = gr.Button("Generate", variant="primary")
            with gr.Column():
                result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                value=42,
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                label="Seed",
            )
            randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
            with gr.Row():
                width = gr.Slider(label="Width", **_size_slider_kwargs)
                height = gr.Slider(label="Height", **_size_slider_kwargs)

        # Gradio passes these positionally, so the order mirrors
        # process_image's parameter list.
        generation_inputs = [
            input_image,
            prompt,
            scale,
            seed,
            randomize_seed,
            width,
            height,
        ]
        run_button.click(
            fn=process_image,
            inputs=generation_inputs,
            outputs=[result, seed],
        )

if __name__ == "__main__":
    # Start the Gradio server when executed as a script.
    demo.launch()