meissonic / app.py
multimodalart's picture
Update app.py
5b3714e verified
raw
history blame
No virus
5.42 kB
import os
import sys
sys.path.append("./")
import torch
from torchvision import transforms
from src.transformer import Transformer2DModel
from src.pipeline import Pipeline
from src.scheduler import Scheduler
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
)
from diffusers import VQModel
import gradio as gr
import spaces
# Run on the GPU when torch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Identifier handed to every ``from_pretrained`` call below; each component
# is read from its own subfolder.
model_path = "MeissonFlow/Meissonic"

# Load the individual pipeline components.
model = Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
)
vq_model = VQModel.from_pretrained(
    model_path,
    subfolder="vqvae",
)
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    model_path,
    subfolder="text_encoder",
)
tokenizer = CLIPTokenizer.from_pretrained(
    model_path,
    subfolder="tokenizer",
)
scheduler = Scheduler.from_pretrained(
    model_path,
    subfolder="scheduler",
)

# Assemble the pipeline from the loaded parts, then hand it the chosen device.
pipe = Pipeline(
    vq_model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=model,
    scheduler=scheduler,
)
pipe.to(device)

# MAX_SEED bounds both the random seed draw and the seed slider;
# MAX_IMAGE_SIZE is the upper limit of the width/height sliders.
MAX_SEED = 2**32 - 1
MAX_IMAGE_SIZE = 1024
@spaces.GPU
def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Run the module-level ``pipe`` once and return its first image and the seed.

    A fresh seed is drawn with ``torch.randint`` when ``randomize_seed`` is
    truthy or when ``seed`` equals 0; otherwise the supplied seed is kept.
    The seed is applied through ``torch.manual_seed`` before the pipeline
    is invoked.

    Args:
        prompt: Forwarded to the pipeline as ``prompt``.
        negative_prompt: Forwarded to the pipeline as ``negative_prompt``.
        seed: Requested seed; the value 0 triggers a random draw.
        randomize_seed: When truthy, ``seed`` is ignored and a new one drawn.
        width: Forwarded to the pipeline as ``width``.
        height: Forwarded to the pipeline as ``height``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``.
        progress: Not referenced in the body; its default is a
            ``gr.Progress`` built with ``track_tqdm=True`` (presumably so
            Gradio can surface tqdm progress in the UI -- confirm).

    Returns:
        Tuple ``(image, seed)``: element 0 of the pipeline result's
        ``images`` and the seed that was actually applied.
    """
    draw_new_seed = randomize_seed or seed == 0
    if draw_new_seed:
        # torch.randint's upper bound is exclusive: result is in [0, MAX_SEED).
        seed = torch.randint(0, MAX_SEED, (1,)).item()

    # Seeds torch's global RNG; no explicit generator is passed to the pipeline.
    torch.manual_seed(seed)

    pipeline_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
    }
    output = pipe(**pipeline_kwargs)
    first_image = output.images[0]
    return first_image, seed
# Default negative prompt: pre-filled as the initial value of the
# "Negative prompt" textbox in the Advanced Settings section of the UI.
default_negative_prompt = "worst quality, normal quality, low quality, low res, blurry, distortion, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch, duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, bad anatomy, bad proportions, bad quality, deformed, disconnected limbs, out of frame, out of focus, dehydrated, disfigured, extra arms, extra limbs, extra hands, fused fingers, gross proportions, long neck, jpeg, malformed limbs, mutated, mutated hands, mutated limbs, missing arms, missing fingers, picture frame, poorly drawn hands, poorly drawn face, collage, pixel, pixelated, grainy, color aberration, amputee, autograph, bad illustration, beyond the borders, blank background, body out of frame, boring background, branding, cut off, dismembered, disproportioned, distorted, draft, duplicated features, extra fingers, extra legs, fault, flaw, grains, hazy, identifying mark, improper scale, incorrect physiology, incorrect ratio, indistinct, kitsch, low resolution"
# Custom stylesheet passed to gr.Blocks: centers the element whose id is
# "col-container" (the main column) and caps its width at 640px.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Sample prompts offered through gr.Examples; each one targets the prompt
# textbox only.
examples = [
"a beautiful landscape painting with a river and mountains in the background",
"a futuristic cityscape at night with flying cars",
"a portrait of a smiling old man with wrinkles and wisdom in his eyes",
]
# Build the UI: a single centered column holding the prompt row, the result
# image, a collapsible settings panel and the example prompts.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Meissonic Text-to-Image Generator")

        # Prompt box and run button share one row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        # Generation controls, collapsed until the user opens them.
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                value=default_negative_prompt,
            )

            # A seed of 0 is treated by generate_image as "draw a random seed".
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # Output dimensions.
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            # Sampling parameters.
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=9.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=48,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    # Clicking "Run" or submitting the prompt box both call generate_image.
    # The seed slider is also an output, so it is updated with the seed the
    # function returns.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate_image,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

demo.launch()