import torch
import gc
import gradio as gr
from main import setup, execute_task
from arguments import parse_args
import os
import shutil
import glob
import time
import threading
import traceback
import argparse


def _iteration_sort_key(image_path):
    """Sort key ordering iteration frames (``0.png``, ``1.png``, ...) numerically.

    Files whose stem is not a plain integer sort after the numbered frames,
    alphabetically by stem.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    if stem.isdigit():
        return (0, int(stem), "")
    return (1, 0, stem)


def list_iter_images(save_dir):
    """Return the per-iteration PNG images saved in ``save_dir``, in iteration order.

    ``best_image.png`` is excluded because it is displayed separately as the
    final result. ``glob`` yields files in arbitrary order, so the paths are
    sorted numerically by file stem (``2.png`` before ``10.png``).
    """
    all_images = glob.glob(os.path.join(save_dir, "*.png"))
    iter_images = [img for img in all_images if os.path.basename(img) != "best_image.png"]
    return sorted(iter_images, key=_iteration_sort_key)


def clean_dir(save_dir):
    """Delete every file, symlink and sub-directory inside ``save_dir``.

    The directory itself is kept. Failures on individual entries are printed
    and skipped (best effort), so a single locked file does not abort the run.
    """
    if not os.path.exists(save_dir):
        print(f"{save_dir} does not exist.")
        return

    entries = os.listdir(save_dir)
    if not entries:
        print(f"{save_dir} exists but is empty.")
        return

    for filename in entries:
        file_path = os.path.join(save_dir, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # Remove file or symbolic link
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # Remove directory and its contents
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")
    print(f"All files in {save_dir} have been deleted.")


def start_over(gallery_state):
    """Reset the outputs before a new run and release cached GPU memory.

    Args:
        gallery_state: The previous gallery state. It is discarded; the
            parameter exists so the function can be wired as a Gradio input.

    Returns:
        ``(gallery_state, output_image, status, iter_gallery_update)``: three
        ``None`` values clearing the state, image and status, plus an update
        that hides the iterations gallery.
    """
    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()
    return None, None, None, gr.update(visible=False)


def setup_model(loaded_model_setup, prompt, model, seed, num_iterations, enable_hps, hps_w,
                enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w,
                learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Build the run arguments and load (or reuse) the requested model.

    This is a generator: it yields ``(status_message, setup)`` pairs, where
    ``setup`` is the list
    ``[args, trainer, device, dtype, shape, enable_grad, settings, pipe]``
    returned by ``main.setup``, or ``None`` if loading failed. Callers should
    keep the last value yielded.

    Raises:
        gr.Error: if ``prompt`` is missing or blank.
    """
    # Validate before announcing the load so an empty prompt fails fast.
    if prompt is None or not prompt.strip():
        raise gr.Error("You forgot to provide a prompt !")
    gr.Info(f"Loading {model} model ...")

    print(f"LOADED_MODEL SETUP: {loaded_model_setup}")

    # Clear CUDA memory before starting the training.
    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()

    # Set up arguments
    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    args.seed = seed
    args.n_iters = num_iterations
    args.lr = learning_rate
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    args.save_all_images = True

    # Each enabled reward is switched on with its UI weight; disabled rewards
    # keep whatever defaults parse_args() provided.
    if enable_hps:
        args.disable_hps = False
        args.hps_weighting = hps_w
    if enable_imagereward:
        args.disable_imagereward = False
        args.imagereward_weighting = imgrw_w
    if enable_pickscore:
        args.disable_pickscore = False
        args.pickscore_weighting = pcks_w
    if enable_clip:
        args.disable_clip = False
        args.clip_weighting = clip_w

    if model == "flux":
        args.cpu_offloading = True
        args.enable_multi_apply = True
        args.multi_step_model = "flux"

    # Compare the new args with the previously loaded ones, ignoring the prompt.
    if loaded_model_setup and hasattr(loaded_model_setup[0], "__dict__"):
        previous_args = loaded_model_setup[0]
        new_args_dict = {k: v for k, v in vars(args).items() if k != "prompt"}
        prev_args_dict = {k: v for k, v in vars(previous_args).items() if k != "prompt"}
        if new_args_dict == prev_args_dict:
            print(f"Arguments (excluding prompt) are the same, reusing loaded setup for {model} model.")
            # Update the prompt in the loaded_model_setup
            loaded_model_setup[0].prompt = prompt
            yield f"{model} model already loaded with the same configuration.", loaded_model_setup

    # setup() runs in every case, including right after the reuse message above.
    # It receives loaded_model_setup, presumably so it can reuse the
    # already-loaded pipeline (TODO confirm in main.setup). The value yielded
    # below supersedes any earlier yield.
    try:
        args, trainer, device, dtype, shape, enable_grad, settings, pipe = setup(args, loaded_model_setup)
        new_loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings, pipe]
        yield f"{model} model loaded successfully!", new_loaded_setup
    except Exception as e:
        print(f"Failed to load {model} model: {e}.")
        traceback.print_exc()
        yield f"Failed to load {model} model: {e}. You can try again, as it usually finally loads on the second try :)", None


def generate_image(setup_args, num_iterations):
    """Run the optimisation task in a worker thread and stream its progress.

    Args:
        setup_args: The list yielded by ``setup_model``
            (``[args, trainer, device, dtype, shape, enable_grad, settings, pipe]``),
            or ``None`` when model loading failed.
        num_iterations: Total iteration count; used only in status messages.

    Yields:
        ``(image_path, status_message, iter_images)`` tuples. While running,
        ``image_path`` is the newest iteration image (or ``None``) and
        ``iter_images`` is ``None``. On success the final tuple carries
        ``best_image.png`` and the ordered list of iteration images.
    """
    if setup_args is None:
        # setup_model yields None when loading fails; indexing it would raise TypeError.
        yield (None, "Model is not loaded. Please check the model status and try again.", None)
        return

    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()

    gr.Info("Executing iterations task ...")

    args, trainer, device, dtype, shape, enable_grad, settings, pipe = setup_args[:8]
    print(f"SETTINGS: {settings}")

    # Must match the directory execute_task writes its images into.
    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
    clean_dir(save_dir)

    try:
        torch.cuda.empty_cache()  # Free up cached memory
        gc.collect()

        # Shared with the worker thread.
        steps_completed = []
        # "error_occurred" flags an out-of-memory failure; "exception" holds any
        # other error raised by execute_task so it can be reported to the user.
        error_status = {"error_occurred": False, "exception": None}

        def progress_callback(step):
            # Record each step once, ignoring repeated or out-of-order reports.
            if not steps_completed or step > steps_completed[-1]:
                steps_completed.append(step)
                print(f"Progress: Step {step} completed.")

        def run_main():
            try:
                execute_task(
                    args, trainer, device, dtype, shape, enable_grad, settings, pipe, progress_callback
                )
            except torch.cuda.OutOfMemoryError as e:
                print(f"CUDA Out of Memory Error: {e}")
                error_status["error_occurred"] = True
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"Runtime Error: {e}")
                    error_status["error_occurred"] = True
                else:
                    # Exceptions raised in a thread never reach this generator,
                    # so hand them over explicitly instead of re-raising.
                    traceback.print_exc()
                    error_status["exception"] = e
            except Exception as e:
                traceback.print_exc()
                error_status["exception"] = e

        main_thread = threading.Thread(target=run_main)
        main_thread.start()

        last_step_yielded = 0
        while main_thread.is_alive() and not error_status["error_occurred"]:
            if steps_completed and steps_completed[-1] > last_step_yielded:
                last_step_yielded = steps_completed[-1]
                # Step N (1-based) is saved as "{N-1}.png".
                png_number = last_step_yielded - 1
                image_path = os.path.join(save_dir, f"{png_number}.png")
                if os.path.exists(image_path):
                    yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
                else:
                    yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
            else:
                time.sleep(0.1)  # Sleep to prevent busy waiting

        main_thread.join()  # Ensure thread completion

        worker_error = error_status["exception"]
        if error_status["error_occurred"]:
            torch.cuda.empty_cache()  # Free up cached memory
            gc.collect()
            yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
        elif worker_error is not None:
            torch.cuda.empty_cache()  # Free up cached memory
            gc.collect()
            yield (None, f"An error occurred: {worker_error}", None)
        else:
            final_image_path = os.path.join(save_dir, "best_image.png")
            if os.path.exists(final_image_path):
                iter_images = list_iter_images(save_dir)
                torch.cuda.empty_cache()  # Free up cached memory
                gc.collect()
                # NOTE(review): short pause kept from the original; its purpose is not
                # visible here (possibly to let file writes settle) — confirm.
                time.sleep(0.5)
                yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
            else:
                torch.cuda.empty_cache()  # Free up cached memory
                gc.collect()
                yield (None, "Image generation completed, but no final image was found.", None)

        torch.cuda.empty_cache()  # Free up cached memory
        gc.collect()

    except torch.cuda.OutOfMemoryError as e:
        print(f"Global CUDA Out of Memory Error: {e}")
        yield (None, f"{e}", None)
    except RuntimeError as e:
        if "out of memory" in str(e):
            print(f"Runtime Error: {e}")
            yield (None, f"{e}", None)
        else:
            yield (None, f"An error occurred: {str(e)}", None)
    except Exception as e:
        print(f"Unexpected Error: {e}")
        yield (None, f"An unexpected error occurred: {str(e)}", None)


def show_gallery_output(gallery_state):
    """Show the iterations gallery with ``gallery_state``, or hide it when there is nothing to show."""
    if gallery_state is not None:
        return gr.update(value=gallery_state, visible=True)
    return gr.update(value=None, visible=False)


def combined_function(gallery_state, loaded_model_setup, prompt, chosen_model, seed, n_iter,
                      enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w,
                      enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Submit handler: reset the UI, load the model, generate, then reveal the gallery.

    Yields 6-tuples matching the click outputs:
    ``(gallery_state, output_image, status, iter_gallery, loaded_model_setup, model_status)``.
    """
    # Step 1: Start Over
    gallery_state, output_image, status, iter_gallery_update = start_over(gallery_state)
    model_status = ""  # No model status yet
    yield gallery_state, output_image, status, iter_gallery_update, loaded_model_setup, model_status

    # Step 2: Setup the model (the last yielded setup is the one kept)
    model_status, new_loaded_model_setup = None, None
    for model_status, new_loaded_model_setup in setup_model(
            loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w,
            enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w,
            learning_rate):
        yield gallery_state, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status

    # Step 3: Generate the image
    output_image, status, gallery_state_update = None, None, None
    for output_image, status, gallery_state_update in generate_image(new_loaded_model_setup, n_iter):
        yield gallery_state_update, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status

    # Step 4: Show the gallery
    iter_gallery_update = show_gallery_output(gallery_state_update)
    yield gallery_state_update, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status


def allow_weighting(weight_type):
    """Return an update making a reward-weight slider editable only while its checkbox is ticked."""
    return gr.update(interactive=weight_type is True)


# Create Gradio interface
title = "# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description = "Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."

css = """
#model-status-id{
    height: 126px;
}
#model-status-id .progress-text{
    font-size: 10px!important;
}
#model-status-id .progress-level-inner{
    font-size: 8px!important;
}
"""

with gr.Blocks(css=css, analytics_enabled=False) as demo:
    loaded_model_setup = gr.State()
    gallery_state = gr.State()

    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        # NOTE(review): this HTML block is whitespace-only and renders nothing;
        # fill in the intended markup or remove it.
        gr.HTML("""
        """)

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                with gr.Row():
                    chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
                    seed = gr.Number(label="seed", value=0)
                    model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=10, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
                with gr.Accordion("Advanced Settings", open=True):
                    with gr.Column():
                        with gr.Row():
                            enable_hps = gr.Checkbox(label="HPS ON", value=False, scale=1)
                            hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
                        with gr.Row():
                            enable_imagereward = gr.Checkbox(label="ImageReward ON", value=False, scale=1)
                            imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
                        with gr.Row():
                            enable_pickscore = gr.Checkbox(label="PickScore ON", value=False, scale=1)
                            pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05, interactive=False, scale=3)
                        with gr.Row():
                            enable_clip = gr.Checkbox(label="CLIP ON", value=False, scale=1)
                            clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        "A red dog and a green cat",
                        "A pink elephant and a grey cow",
                        "A toaster riding a bike",
                        "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
                        "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
                        "An epic oil painting: a red portal infront of a cityscape, a solitary figure, and a colorful sky over snowy mountains",
                    ],
                    inputs=[prompt],
                )

            with gr.Column():
                output_image = gr.Image(type="filepath", label="Best Generated Image")
                status = gr.Textbox(label="Status")
                iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)

    # Each reward's weight slider is only editable while its checkbox is on.
    for reward_toggle, reward_weight in (
        (enable_hps, hps_w),
        (enable_imagereward, imgrw_w),
        (enable_pickscore, pcks_w),
        (enable_clip, clip_w),
    ):
        reward_toggle.change(
            fn=allow_weighting,
            inputs=[reward_toggle],
            outputs=[reward_weight],
            queue=False,
        )

    submit_btn.click(
        fn=combined_function,
        inputs=[
            gallery_state, loaded_model_setup, prompt, chosen_model, seed, n_iter,
            enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w,
            enable_clip, clip_w, learning_rate,
        ],
        outputs=[
            gallery_state, output_image, status, iter_gallery, loaded_model_setup,
            model_status,  # Ensure `model_status` is included in the outputs
        ],
    )

# Launch the app
demo.queue().launch(show_error=True, show_api=False)