# Page-capture residue ("Spaces: Running"); not part of the program.
import base64
import io
import json
import os

import gradio as gr
import requests
import torch
from PIL import Image
# Module-level gallery state shared by every callback below.
# The two lists are kept index-aligned: rewards_in_gallery[i] is the score of
# images_in_gallery[i], or None while that image has not been scored yet.
images_in_gallery = []
rewards_in_gallery = []
def generate_images(
    prompt, magic_words, num, height, width, num_inference_steps, guidance_scale
):
    """Request generated images from the remote service and load them.

    The previous gallery content is discarded: on success the module-level
    ``images_in_gallery`` is replaced by the decoded images and
    ``rewards_in_gallery`` by a same-length list of ``None`` (not scored yet).
    Both globals are left untouched if the request or the decoding fails.

    Args:
        prompt: Text prompt entered by the user.
        magic_words: Selected prompt suffixes, or ``None``. A non-empty
            selection is appended to the prompt, comma-separated.
        num: Number of images to request.
        height: Requested image height.
        width: Requested image width.
        num_inference_steps: Forwarded unchanged to the service.
        guidance_scale: Forwarded unchanged to the service.

    Returns:
        A list of ``(image, None)`` pairs, one per decoded image.

    Raises:
        requests.HTTPError: If the service answers with an error status.
    """
    global images_in_gallery, rewards_in_gallery

    if magic_words:
        # Join the prompt and the magic words with one separator; plain
        # concatenation would fuse the prompt's last word with the first
        # magic word ("...impressionismHDR, UHD, ...").
        parts = ([prompt] if prompt else []) + list(magic_words)
        prompt = ", ".join(parts)

    # Send the generation parameters to the server as a JSON POST request.
    url = 'https://tianqi.aminer.cn/image_reward_hf/generate_image'
    payload = {
        'prompt': prompt,
        'height': height,
        'width': width,
        'num_inference_steps': num_inference_steps,
        'guidance_scale': guidance_scale,
        'num': num,
    }
    headers = {'Content-Type': 'application/json'}
    response = requests.post(url, data=json.dumps(payload), headers=headers)
    # Fail with the HTTP error itself rather than a confusing JSON/KeyError.
    response.raise_for_status()
    image_ls = response.json()['image_list']

    # Each entry is a Base64 string; decode it and open the resulting byte
    # stream as an image. Decode everything before touching the globals so a
    # bad entry cannot leave the gallery half-updated.
    decoded_images = [
        Image.open(io.BytesIO(base64.b64decode(base_image)))
        for base_image in image_ls
    ]

    images_in_gallery = decoded_images
    rewards_in_gallery = [None] * len(images_in_gallery)
    return list(zip(images_in_gallery, rewards_in_gallery))
def _encode_image_as_jpeg_base64(image):
    """Return *image* serialized as JPEG and encoded as a Base64 string."""
    # JPEG cannot store alpha or palette modes (e.g. RGBA/P from uploaded
    # PNGs); saving those raises OSError, so convert them to RGB first.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def score_and_rank(prompt):
    """Score every unscored gallery image against *prompt* and sort the gallery.

    Images whose reward is still ``None`` are sent to the remote scoring
    service; the returned rewards are merged into ``rewards_in_gallery`` and
    both module-level lists are re-ordered by descending reward. When nothing
    is unscored, no request is made and the current state is returned.

    Args:
        prompt: Text the images are scored against.

    Returns:
        A 2-tuple ``(pairs, table)``: ``pairs`` is a list of
        ``(image, reward)`` tuples in gallery order and ``table`` is a list of
        ``[rank, reward]`` rows with 1-based ranks.

    Raises:
        requests.HTTPError: If the service answers with an error status.
        ValueError: If the service returns a different number of rewards
            than images submitted.
    """
    global rewards_in_gallery, images_in_gallery

    num_not_scored = rewards_in_gallery.count(None)
    if num_not_scored > 0:
        # Unscored entries always sit at the tail: generation resets every
        # reward to None, uploads append None entries, and a scoring pass
        # leaves no None behind.
        images_to_score = images_in_gallery[-num_not_scored:]
        image_ls = [
            _encode_image_as_jpeg_base64(image) for image in images_to_score
        ]

        # Send the encoded images and the prompt to the server as JSON.
        url = 'https://tianqi.aminer.cn/image_reward_hf/score_and_rank'
        data = json.dumps({'images_to_score': image_ls, 'prompt': prompt})
        headers = {'Content-Type': 'application/json'}
        response = requests.post(url, data=data, headers=headers)
        response.raise_for_status()

        rewards = response.json()['rewards']
        # The service may answer with a bare value instead of a list.
        if not isinstance(rewards, list):
            rewards = [rewards]
        if len(rewards) != num_not_scored:
            # A mismatch would silently desynchronise images and rewards.
            raise ValueError(
                f"expected {num_not_scored} rewards, got {len(rewards)}"
            )

        rewards_in_gallery = rewards_in_gallery[:-num_not_scored] + rewards
        outputs = sorted(
            zip(images_in_gallery, rewards_in_gallery),
            key=lambda pair: pair[1],
            reverse=True,
        )
        images_in_gallery = [image for image, _ in outputs]
        rewards_in_gallery = [reward for _, reward in outputs]
    else:
        outputs = list(zip(images_in_gallery, rewards_in_gallery))

    ranking = [
        [idx + 1, reward] for idx, reward in enumerate(rewards_in_gallery)
    ]
    return outputs, ranking
def upload_images_to_gallery(uploaded_image_files):
    """Append uploaded images to the gallery as not-yet-scored entries.

    Each uploaded file is read fully into memory and its on-disk file is then
    removed. The images are appended to ``images_in_gallery`` and a matching
    number of ``None`` rewards to ``rewards_in_gallery``.

    Args:
        uploaded_image_files: Iterable of file objects exposing the path of
            the uploaded file as ``.name``; ``None`` is treated as empty.

    Returns:
        A list of ``(image, reward)`` pairs for the whole gallery.
    """
    global images_in_gallery, rewards_in_gallery

    uploaded_images = []
    for file in uploaded_image_files or []:
        path = file.name
        # Image.open is lazy: copy the pixel data into memory and close the
        # file handle before deleting the file, so no handle is leaked and
        # the removal also works where open files cannot be deleted.
        with Image.open(path) as opened:
            uploaded_images.append(opened.copy())
        os.remove(path)

    images_in_gallery = images_in_gallery + uploaded_images
    rewards_in_gallery = rewards_in_gallery + [None] * len(uploaded_images)
    return list(zip(images_in_gallery, rewards_in_gallery))
def clear_images():
    """Empty the gallery state.

    Resets both module-level lists to new empty lists.

    Returns:
        ``None``.
    """
    global images_in_gallery, rewards_in_gallery
    images_in_gallery = []
    rewards_in_gallery = []
    return None
if __name__ == "__main__":
    # Build the Gradio UI and wire the callbacks defined above.
    # NOTE(review): the nesting of rows/columns below was reconstructed from a
    # flattened copy of this file; it follows the on-page instructions
    # ("Generate" in the bottom middle, upload in the bottom right) — confirm
    # against the deployed layout.
    with gr.Blocks(
        theme=gr.themes.Monochrome(),
        css=r".caption-label { color: black; }",
    ) as demo:
        # Page header with project links and a short description.
        gr.HTML(
            """
<h1 align="center">ImageReward Demo</h1>
<p align="center"><a href="https://github.com/THUDM/ImageReward">GitHub Repo</a> • 🤗 <a href="https://huggingface.co/THUDM/ImageReward" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/abs/2304.05977" target="_blank">Paper</a><br></p>
<br>
<p dir="auto">ImageReward is the first general-purpose text-to-image <strong>human preference RM</strong>, which is trained on in total <strong>137k pairs of expert comparisons</strong>!</p>
<p dir="auto">The calculation of ImageRewards is based on <strong>both the prompt and images</strong>.</p>
"""
        )

        # Usage instructions: generate-and-score (left), upload-and-score (right).
        with gr.Row():
            with gr.Column():
                gr.HTML(
                    """
<p dir="auto">Try ImageReward with only 2 steps:</p>
<ol dir="auto">
<li>Click the <strong>"Generate"</strong> button <strong>in the middle of the bottom</strong>.</li>
<li>Click the <strong>"Score&Rank"</strong> button <strong>below the gallery</strong>.</li>
</ol>
<p dir="auto">Finally, just check ImageRewards <strong>along with images or on the right of the gallery</strong>.</p>
<br>
<p dir="auto">This demo uses <code>runwayml/stable-diffusion-v1-5</code> as image generation model.</p>
"""
                )
            with gr.Column():
                gr.HTML(
                    """
<p dir="auto">Besides generating images, you can also <strong>upload</strong> images to score:</p>
<ol dir="auto">
<li>Upload images <strong>in the bottom right corner</strong>.</li>
<li>Change the <strong>"Prompt"</strong> to correspond to the images.</li>
<li>Click the <strong>"Score&Rank"</strong> button <strong>below the gallery</strong>.</li>
</ol>
<br>
<p dir="auto">For more details about using ImageReward in your own program, check <a href="https://github.com/THUDM/ImageReward">the README.md in our Github Repo</a>.</p>
"""
                )

        # Outputs: image gallery plus the rank/reward table.
        with gr.Row(elem_id="outputs_row"):
            with gr.Column(elem_id="gallery_column", scale=4):
                gallery = gr.Gallery(
                    label="Images (scored ones sorted)",
                    show_label=False,
                    elem_id="gallery",
                ).style(columns=4, object_fit="contain", full_width=True)
            with gr.Column(elem_id="rewards_column"):
                rewards = gr.Matrix(
                    value=[[None, None]],
                    headers=["Rank", "ImageReward"],
                    datatype="number",
                )

        with gr.Row():
            score_and_rank_button = gr.Button("Score&Rank")
            clear_button = gr.Button("Clear Gallery")

        # Inputs: prompt (left), generation settings (middle), uploads (right).
        with gr.Row().style(equal_height=True):
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    value="A painting of an ocean with clouds and birds, day time, low depth field effect, oil painting, impressionism",
                )
                examples = [
                    "A painting of an ocean with clouds and birds, day time, low depth field effect, oil painting, impressionism",
                    "A painting of a girl walking in a hallway and suddenly finds a giant sunflower on the floor blocking her way",
                    "Coronation of the sun emperor, digital art, illustration,4k resolution,intricate extremely detailed, depth,vivid colors",
                    "Symmetry!! Product render poster vivid colors divine proportion owl,glowing fog intricate,elegant, highly detailed",
                    "A unicorn in a clearing.it has a single shining horn. volumetric light.by emmanuel shiu, harry potter, eragon",
                    "Highly detailed portrait of a woman with long hairs,stephen bliss. unreal engine, fantasy art by greg rutkowski",
                    "Sculpture made of flame,portrait, female,future, torch,fire,harper's bazaar,vogue, fashion magazine, intricate",
                ]
                prompt_examples = gr.Examples(
                    examples=examples,
                    label="Prompt Examples",
                    inputs=[prompt],
                    elem_id="prompt_examples",
                )
            with gr.Column():
                choices = [
                    "HDR, UHD, 4K, 8K, 64K",
                    "highly detailed",
                    "studio lighting",
                    "professional",
                    "trending on artstation",
                    "unreal engine",
                    "vivid colors",
                ]
                # Every magic word is pre-selected by default.
                magic_words = gr.CheckboxGroup(
                    choices=choices,
                    value=choices,
                    type="value",
                    label="Magic Words to Append to Prompt",
                )
                num = gr.Slider(1, 16, step=1, label="Number of images", value=8)
                height = gr.Slider(256, 2048, step=256, label="Height", value=512)
                width = gr.Slider(256, 2048, step=256, label="Width", value=512)
                # NOTE(review): a minimum of 0 inference steps is selectable;
                # confirm the generation service tolerates it.
                num_inference_steps = gr.Slider(
                    0, 200, step=10, label="Number of inference steps", value=50
                )
                guidance_scale = gr.Slider(
                    0, 25, step=0.1, label="Guidance scale", value=7.5
                )
                generate_button = gr.Button("Generate")
            with gr.Column():
                gr.Markdown(
                    """
- To clear all uploaded images, click the **"Clear Gallery"** button above.
- To clear the upload list and add additional images, click the **`x` in the upper right corner of the uploading window**.
- Additional images will be appended to the gallery, instead of replacing the existing ones.
"""
                )
                uploaded_image_files = gr.File(
                    file_count="multiple",
                    file_types=["image"],
                    type="file",
                    label="Upload Images",
                    show_label=True,
                )

        # Event wiring: each callback's return value feeds the listed outputs.
        generate_button.click(
            generate_images,
            [
                prompt,
                magic_words,
                num,
                height,
                width,
                num_inference_steps,
                guidance_scale,
            ],
            [gallery],
        )
        score_and_rank_button.click(score_and_rank, [prompt], [gallery, rewards])
        uploaded_image_files.upload(
            upload_images_to_gallery, [uploaded_image_files], [gallery]
        )
        clear_button.click(clear_images, None, [gallery])

    demo.launch()