# ReNO / app.py
# Page header captured with this file: author fffiloni, commit dd8f929
# ("refactoring for Flux"), 11.9 kB.
import torch
import gradio as gr
from main import setup, execute_task
from arguments import parse_args
import os
import shutil
import glob
import time
import threading
import argparse
def list_iter_images(save_dir):
    """Return the image files located directly inside ``save_dir``.

    Args:
        save_dir: Directory to scan (not searched recursively).

    Returns:
        A list of paths whose extension is one of jpg, jpeg, png, gif or bmp.
        Files named with a plain number (``0.png``, ``1.png``, ...) come first
        in numeric order; any other image (e.g. ``best_image.png``) follows in
        alphabetical order. An empty list is returned when the directory is
        missing or holds no matching file.
    """
    image_extensions = ('jpg', 'jpeg', 'png', 'gif', 'bmp')  # Add more if needed

    # The directory name is built from free-form text (the prompt), so it may
    # contain glob metacharacters such as "[", "*" or "?". Escape it so only
    # the "*.ext" part of the pattern is treated as a wildcard.
    escaped_dir = glob.escape(save_dir)

    image_paths = []
    for ext in image_extensions:
        image_paths.extend(glob.glob(os.path.join(escaped_dir, f'*.{ext}')))

    def _iteration_order(path):
        # Numeric stems sort by value (2 before 10); everything else by name.
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdecimal():
            return (0, int(stem), "")
        return (1, 0, stem)

    # glob gives no ordering guarantee; sort so the result is deterministic.
    image_paths.sort(key=_iteration_order)
    return image_paths
def clean_dir(save_dir):
    """Empty ``save_dir`` without removing the directory itself.

    Regular files and symbolic links are unlinked, sub-directories are removed
    recursively. A failure on one entry is printed and does not stop the
    remaining entries from being processed. Nothing is returned; the outcome
    is reported with ``print``.
    """
    # Guard: nothing to do when the directory is absent.
    if not os.path.exists(save_dir):
        print(f"{save_dir} does not exist.")
        return

    entries = os.listdir(save_dir)

    # Guard: directory is present but already empty.
    if not entries:
        print(f"{save_dir} exists but is empty.")
        return

    for entry_name in entries:
        file_path = os.path.join(save_dir, entry_name)
        try:
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                # A real directory: drop it together with its contents.
                shutil.rmtree(file_path)
            elif os.path.isfile(file_path) or os.path.islink(file_path):
                # A file, or a link of any kind (the link is removed, not its target).
                os.unlink(file_path)
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")

    print(f"All files in {save_dir} have been deleted.")
def start_over(gallery_state, loaded_model_setup):
    """Reset the UI and the stored state before a new run.

    Both arguments are accepted only because they are wired as event inputs;
    whatever they hold, the stored values are replaced by ``None``.

    Returns:
        A 5-tuple: cleared gallery state, cleared output image, cleared status
        text, an update that hides the iteration gallery, and cleared model
        setup (so a previous setup cannot be reused by accident).
    """
    torch.cuda.empty_cache()  # Free up cached memory

    hide_gallery = gr.update(visible=False)
    return None, None, None, hide_gallery, None
def setup_model(prompt, model, seed, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
    """Build the run configuration for one prompt and load the model.

    Args:
        prompt: Text prompt to optimise an image for. Must not be empty.
        model: Model identifier; ``"flux"`` additionally switches on CPU
            offloading and multi-apply.
        seed: Random seed stored on the arguments.
        num_iterations: Number of optimisation iterations (``args.n_iters``).
        learning_rate: Learning rate (``args.lr``).
        hps_w: Weight stored as ``args.hps_weighting``.
        imgrw_w: Weight stored as ``args.imagereward_weighting``.
        pcks_w: Weight stored as ``args.pickscore_weighting``.
        clip_w: Weight stored as ``args.clip_weighting``.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        ``(None, loaded_setup)`` where ``loaded_setup`` is the list
        ``[args, trainer, device, dtype, shape, enable_grad, multi_apply_fn,
        settings]`` produced by ``setup``. The leading ``None`` clears the
        output image component.

    Raises:
        gr.Error: If the prompt is missing, empty or only whitespace.
    """
    # A text box left blank may arrive as "" rather than None, so test for
    # emptiness instead of identity with None.
    if prompt is None or not str(prompt).strip():
        raise gr.Error("You forgot the prompt !")

    # Clear CUDA memory before starting the training.
    torch.cuda.empty_cache()  # Free up cached memory

    # Set up arguments
    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    args.seed = seed
    args.n_iters = num_iterations
    args.lr = learning_rate
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    args.save_all_images = True
    args.hps_weighting = hps_w
    args.imagereward_weighting = imgrw_w
    args.pickscore_weighting = pcks_w
    args.clip_weighting = clip_w

    if model == "flux":
        args.cpu_offloading = True
        args.enable_multi_apply = True
        args.multi_step_model = "flux"

    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
    loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
    return None, loaded_setup
def generate_image(setup_args, num_iterations):
    """Run the optimisation in a worker thread and stream progress updates.

    This is a generator. Every yielded item is a 3-tuple
    ``(image_path_or_None, status_message, gallery_paths_or_None)``.

    Args:
        setup_args: The 8-item list produced by ``setup_model``
            (``args, trainer, device, dtype, shape, enable_grad,
            multi_apply_fn, settings``), or ``None`` if no setup is available.
        num_iterations: Total iteration count, used only in status messages.

    Yields:
        One update per newly completed step while the worker is running, then
        a final update carrying ``best_image.png`` and the list of iteration
        images, or an error message if the run failed.
    """
    import traceback  # local import: only needed to log worker failures

    # The previous step of the event chain may have failed (e.g. empty
    # prompt), leaving the stored setup as None. Report that instead of
    # crashing on the unpacking below.
    if not setup_args:
        yield (None, "Model setup is missing. Please submit a prompt to load a model first.", None)
        return

    torch.cuda.empty_cache()  # Free up cached memory

    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup_args[:8]

    print(f"SETTINGS: {settings}")
    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
    clean_dir(save_dir)

    try:
        torch.cuda.empty_cache()  # Free up cached memory
        steps_completed = []
        # Written by the worker thread, read by this generator:
        #   "oom"     - True once an out-of-memory error was caught
        #   "failure" - any other exception raised by the task, else None
        worker_state = {"oom": False, "failure": None}

        def progress_callback(step):
            # Limit redundant prints by checking the step number
            if not steps_completed or step > steps_completed[-1]:
                steps_completed.append(step)
                print(f"Progress: Step {step} completed.")

        def run_main():
            try:
                execute_task(
                    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback
                )
            except torch.cuda.OutOfMemoryError as e:
                print(f"CUDA Out of Memory Error: {e}")
                worker_state["oom"] = True
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(f"Runtime Error: {e}")
                    worker_state["oom"] = True
                else:
                    # Keep the traceback in the logs and hand the error to
                    # the generator so the UI can report it.
                    traceback.print_exc()
                    worker_state["failure"] = e
            except Exception as e:
                traceback.print_exc()
                worker_state["failure"] = e

        main_thread = threading.Thread(target=run_main)
        main_thread.start()

        last_step_yielded = 0
        while main_thread.is_alive() and not worker_state["oom"]:
            # Check if new steps have been completed
            if steps_completed and steps_completed[-1] > last_step_yielded:
                last_step_yielded = steps_completed[-1]
                # Image files are looked up under the step number minus one.
                png_number = last_step_yielded - 1
                # Get the image for this step
                image_path = os.path.join(save_dir, f"{png_number}.png")
                if os.path.exists(image_path):
                    yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
                else:
                    yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
            else:
                time.sleep(0.1)  # Sleep to prevent busy waiting

        if worker_state["oom"]:
            torch.cuda.empty_cache()  # Free up cached memory
            yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
        else:
            main_thread.join()  # Ensure thread completion

            # Re-raise a worker failure here so the handlers below turn it
            # into a status message instead of a misleading "completed".
            if worker_state["failure"] is not None:
                raise worker_state["failure"]

            final_image_path = os.path.join(save_dir, "best_image.png")
            if os.path.exists(final_image_path):
                iter_images = list_iter_images(save_dir)
                torch.cuda.empty_cache()  # Free up cached memory
                time.sleep(0.5)
                yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
            else:
                torch.cuda.empty_cache()  # Free up cached memory
                yield (None, "Image generation completed, but no final image was found.", None)

        torch.cuda.empty_cache()  # Free up cached memory

    except torch.cuda.OutOfMemoryError as e:
        print(f"Global CUDA Out of Memory Error: {e}")
        yield (None, "CUDA out of memory.", None)
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print(f"Runtime Error: {e}")
            yield (None, "CUDA out of memory.", None)
        else:
            yield (None, f"An error occurred: {str(e)}", None)
    except Exception as e:
        print(f"Unexpected Error: {e}")
        yield (None, f"An unexpected error occurred: {str(e)}", None)
def show_gallery_output(gallery_state):
    """Return a gallery update: shown with the stored images, or hidden and empty.

    The gallery is made visible exactly when ``gallery_state`` is not ``None``;
    in that case it receives ``gallery_state`` as its value.
    """
    has_images = gallery_state is not None
    return gr.update(value=gallery_state if has_images else None, visible=has_images)
# Create Gradio interface
title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."

with gr.Blocks(analytics_enabled=False) as demo:
    # Per-session state: the loaded model setup and the list of iteration images.
    loaded_model_setup = gr.State()
    gallery_state = gr.State()

    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        # The HTML literal is kept flush-left so its content is unchanged.
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://github.com/ExplainableML/ReNO'>
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href='https://arxiv.org/abs/2406.04312v1'>
<img src='https://img.shields.io/badge/Paper-Arxiv-red'>
</a>
</div>
""")

        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")

                with gr.Row():
                    chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
                    seed = gr.Number(label="seed", value=0)

                model_status = gr.Textbox(label="model status", visible=False)

                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Column():
                        hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0)
                        imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0)
                        pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05)
                        clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01)

                submit_btn = gr.Button("Submit")

                gr.Examples(
                    examples = [
                        "A red dog and a green cat",
                        "A pink elephant and a grey cow",
                        "A toaster riding a bike",
                        "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
                        "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
                        "An epic oil painting: a red portal infront of a cityscape, a solitary figure, and a colorful sky over snowy mountains"
                    ],
                    inputs = [prompt]
                )

            # Right column: outputs.
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Best Generated Image")
                status = gr.Textbox(label="Status")
                iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)

    # Event chain: reset -> load model -> run optimisation -> reveal gallery.
    submit_btn.click(
        fn = start_over,
        inputs =[gallery_state, loaded_model_setup], # Reset loaded model setup as well
        outputs = [gallery_state, output_image, status, iter_gallery, loaded_model_setup] # Ensure loaded_model_setup is reset
    ).then(
        fn = setup_model,
        # Inputs are passed positionally, so this order must match the
        # signature of setup_model:
        # (prompt, model, seed, num_iterations, learning_rate,
        #  hps_w, imgrw_w, pcks_w, clip_w)
        inputs = [prompt, chosen_model, seed, n_iter, learning_rate, hps_w, imgrw_w, pcks_w, clip_w],
        outputs = [output_image, loaded_model_setup] # Load the new setup into the state
    ).then(
        fn = generate_image,
        inputs = [loaded_model_setup, n_iter],
        outputs = [output_image, status, gallery_state]
    ).then(
        fn = show_gallery_output,
        inputs = [gallery_state],
        outputs = iter_gallery
    )

# Launch the app
demo.queue().launch(show_error=True, show_api=False)