import base64
import datetime
import hashlib
import io
import json
import os

import gradio as gr
import requests
from PIL import Image

from utils import build_logger

LOGDIR = "log"
logger = build_logger("otter", LOGDIR)

# no_change_btn = gr.Button.update()
# enable_btn = gr.Button.update(interactive=True)
# disable_btn = gr.Button.update(interactive=False)
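# Note: ``gr.Button.update`` comes from the older Gradio 3.x API; these helpers are not
# referenced anywhere below, which is presumably why they are left commented out.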
def decode_image(encoded_image: str) -> Image.Image:
    """Decode a base64-encoded string back into a PIL image."""
    decoded_bytes = base64.b64decode(encoded_image.encode("utf-8"))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image

def encode_image(image: Image.Image, format: str = "PNG") -> str:
    """Encode a PIL image as a base64 string (PNG by default)."""
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return encoded_image

def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_conv_image_dir():
    name = os.path.join(LOGDIR, "images")
    os.makedirs(name, exist_ok=True)
    return name

def get_image_name(image, image_dir=None):
    """Name an image by the MD5 hash of its PNG bytes, so identical uploads map to one file."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    image_bytes = buffer.getvalue()
    md5 = hashlib.md5(image_bytes).hexdigest()
    if image_dir is not None:
        image_name = os.path.join(image_dir, md5 + ".png")
    else:
        image_name = md5 + ".png"
    return image_name

def resize_image(image, max_size):
    width, height = image.size
    aspect_ratio = float(width) / float(height)
    if width > height:
        new_width = max_size
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = max_size
        new_width = int(new_height * aspect_ratio)
    resized_image = image.resize((new_width, new_height))
    return resized_image

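# Note: decode_image, the conversation-log helpers, and resize_image above are general
# utilities; only encode_image and http_bot are actually wired into the Gradio demo below.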
def http_bot(image_input, text_input, request: gr.Request):
    """Send the uploaded image and prompt to the remote OtterHD endpoint and return its answer."""
    logger.info(f"http_bot. ip: {request.client.host}")
    print(f"Prompt request: {text_input}")
    # The image travels as a base64-encoded PNG string alongside the prompt.
    base64_image_str = encode_image(image_input)
    payload = {
        "content": [
            {
                "prompt": text_input,
                "image": base64_image_str,
            }
        ],
        "token": "sk-OtterHD",
    }
    # Log only a short prefix of the base64 string to keep the console readable.
    print(
        "request: ",
        {
            "prompt": text_input,
            "image": base64_image_str[:10],
        },
    )
    url = "http://10.128.0.40:8890/app/otter"
    headers = {"Content-Type": "application/json"}
    response = requests.post(url, headers=headers, data=json.dumps(payload))
    results = response.json()
    print("response: ", {"result": results["result"]})
    return results["result"]

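# Sketch (assumption, not part of this app's flow): a receiving service could unpack the
# payload built in http_bot with the decode_image helper defined above, e.g.
#
#   entry = payload["content"][0]
#   prompt = entry["prompt"]
#   image = decode_image(entry["image"])  # base64 PNG -> PIL.Image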
title = """ | |
# OTTER-HD: A High-Resolution Multi-modality Model | |
[[Otter Codebase]](https://github.com/Luodian/Otter) [[Paper]]() [[Checkpoints & Benchmarks]](https://huggingface.co/Otter-AI) | |
""" | |
css = """ | |
#mkd { | |
height: 1000px; | |
overflow: auto; | |
border: 1px solid #ccc; | |
} | |
""" | |
if __name__ == "__main__":
    with gr.Blocks(css=css) as demo:
        gr.Markdown(title)
        dialog_state = gr.State()
        input_state = gr.State()
        with gr.Tab("Ask a Question"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=2):
                    image_input = gr.Image(label="Upload a High-Res Image", type="pil")
                with gr.Column(scale=1):
                    vqa_output = gr.Textbox(label="Output")
                    text_input = gr.Textbox(label="Ask a Question")
                    vqa_btn = gr.Button("Send It")
            gr.Examples(
                [
                    [
                        "./assets/IMG_00095.png",
                        "How many camels are inside this image?",
                    ],
                    [
                        "./assets/IMG_00095.png",
                        "How many people are inside this image?",
                    ],
                    [
                        "./assets/IMG_00012.png",
                        "How many apples are there?",
                    ],
                    # ["./assets/./IMG_00012.png", "How many apples are there? Count them row by row."],
                    [
                        "./assets/IMG_00080.png",
                        "What is this and where is it from?",
                    ],
                    [
                        "./assets/IMG_00094.png",
                        "What's important on this website?",
                    ],
                ],
                inputs=[image_input, text_input],
                outputs=[vqa_output],
                fn=http_bot,
                label="Click on any Examples below👇",
            )
        vqa_btn.click(fn=http_bot, inputs=[image_input, text_input], outputs=vqa_output)
    demo.launch()