import logging
import os

import gradio as gr
from main import main
from arguments import parse_args

logger = logging.getLogger(__name__)


def _build_settings_name(args):
    """Return the run-settings directory name used to locate outputs.

    The name concatenates the model, optional seed/optimization flags, the
    optimizer hyper-parameters and every enabled reward weighting.
    NOTE(review): this has to match the directory name ``main`` writes its
    results to, otherwise the finished image will not be found -- keep the
    two in sync (confirm against main.py whenever either changes).
    """
    parts = [
        f"{args.model}",
        f"_{args.prompt}" if args.task == "t2i-compbench" else "",
        "_no-optim" if args.no_optim else "",
        f"_{args.seed if args.task != 'geneval' else ''}",
        f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}",
        f"_reg{args.reg_weight if args.enable_reg else '0'}",
        f"_pickscore{args.pickscore_weighting}" if args.enable_pickscore else "",
        f"_clip{args.clip_weighting}" if args.enable_clip else "",
        f"_hps{args.hps_weighting}" if args.enable_hps else "",
        f"_imagereward{args.imagereward_weighting}" if args.enable_imagereward else "",
        f"_aesthetic{args.aesthetic_weighting}" if args.enable_aesthetic else "",
    ]
    return "".join(parts)


def generate_image(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Run ReNO on a single prompt and return the best image for the UI.

    Args:
        prompt: Text prompt entered in the UI.
        model: Model key selected in the dropdown.
        num_iterations: Number of optimization iterations (slider value).
        learning_rate: Optimization learning rate (slider value).
        progress: Not used in the body. Declaring a ``gr.Progress`` default
            lets Gradio attach a progress tracker to this event, and
            ``track_tqdm=True`` asks it to mirror tqdm bars emitted by ``main``.

    Returns:
        A ``(image_path, status_message)`` tuple; ``image_path`` is ``None``
        when no image could be produced.
    """
    # Fail fast with a readable message: the model dropdown has no default
    # value, and an empty prompt would otherwise be handed straight to main().
    if not model:
        return None, "Please select a model."
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."

    # Start from the CLI defaults, then override with the UI values.
    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    # Normalize slider values so main() always gets an integer iteration
    # count and the output-directory name is formatted consistently.
    args.n_iters = int(num_iterations)
    args.lr = float(learning_rate)
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    args.save_all_images = True

    try:
        main(args)

        # Rebuild the directory main() saved into and look for the best image.
        # Plain string formatting (not os.path.join) so a prompt that looks
        # like an absolute path cannot re-root the lookup.
        save_dir = f"{args.save_dir}/{args.task}/{_build_settings_name(args)}/{args.prompt}"
        image_path = f"{save_dir}/best_image.png"
        if os.path.exists(image_path):
            return image_path, f"Image generated successfully and saved at {image_path}"
        return None, "Image generation completed, but the file was not found."
    except Exception as e:
        # Top-level UI boundary: report the failure in the status box, but keep
        # the full traceback in the server log instead of discarding it.
        logger.exception("ReNO generation failed for prompt %r", prompt)
        return None, f"An error occurred: {e}"


# Create Gradio interface
title = "# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description = "Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        # NOTE(review): this HTML block only contains a newline as written;
        # links/badges may have been intended here -- confirm.
        gr.HTML("""
""")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"], label="Model")
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        "A minimalist logo design of a reindeer, fully rendered. The reindeer features distinct, complete shapes using bold and flat colors. The design emphasizes simplicity and clarity, suitable for logo use with a sharp outline and white background.",
                        "A blue scooter is parked near a curb in front of a green vintage car",
                        "A impressionistic oil painting: a lone figure walking on a misty beach, a weathered lighthouse on a cliff, seagulls above crashing waves",
                        "A bird with 8 legs",
                        "An orange chair to the right of a black airplane",
                        "A pink elephant and a grey cow",
                    ],
                    inputs=[prompt],
                )
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Generated Image")
                status = gr.Textbox(label="Status")

    submit_btn.click(
        fn=generate_image,
        inputs=[prompt, chosen_model, n_iter, learning_rate],
        outputs=[output_image, status],
    )

# Launch the app
demo.queue().launch(show_error=True, show_api=False)