# Vision / app.py
# Staticaliza's picture
# Update app.py
# baefccb verified
# raw
# history blame
# 2.94 kB
# Imports
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import argparse
from decord import VideoReader, cpu
import io
import os
import copy
import requests
import base64
import json
import traceback
import re
import modelscope_studio as mgr
# Pre-Initialize
# "auto" means: pick the device from what torch reports at start-up.
DEVICE = "auto"
if DEVICE == "auto":
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"

# Log the resolved device once so it shows up in the Space logs.
print(f"[SYSTEM] | Using {DEVICE} type compute device.")
# Variables
# Prompt used when the caller does not supply an instruction.
DEFAULT_INPUT = "Describe in one paragraph."

# trust_remote_code=True lets transformers execute the modelling code that is
# published with the checkpoint on the Hub.
repo = AutoModel.from_pretrained("openbmb/MiniCPM-V-2_6", torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6", trust_remote_code=True)

# Inference only: switch to eval mode and place the weights on the device
# chosen during pre-initialisation (previously DEVICE was computed but the
# model was never moved to it).
repo = repo.eval().to(DEVICE)

# Page styling handed to gr.Blocks: narrow container, centred h1, hidden footer.
css = '''
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
visibility: hidden
}
'''
# Functions
@spaces.GPU(duration=60)
def generate(image, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
    """Run one single-turn chat over an image and return the model's reply.

    Args:
        image: A file path or file object, a PIL image, or array-like pixel
            data. What ``gr.Image`` hands over depends on its ``type``
            setting, so all three forms are accepted.
        instruction: Text prompt sent alongside the image.
        sampling: Forwarded to ``repo.chat`` as ``sampling``.
        temperature: Forwarded to ``repo.chat`` as ``temperature``.
        top_p: Forwarded to ``repo.chat`` as ``top_p``.
        top_k: Forwarded to ``repo.chat`` as ``top_k`` (cast to ``int``).
        repetition_penalty: Forwarded to ``repo.chat`` as
            ``repetition_penalty``.
        max_tokens: Forwarded to ``repo.chat`` as ``max_new_tokens`` (cast
            to ``int``).

    Returns:
        Whatever ``repo.chat`` returns -- presumably the generated text,
        since it is routed to the output textbox.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque traceback.
        raise gr.Error("Please provide an image.")

    print(image)

    # Normalise every supported input form to an RGB PIL image.
    if isinstance(image, Image.Image):
        image_rgb = image.convert("RGB")
    elif isinstance(image, (str, os.PathLike)) or hasattr(image, "read"):
        # Close the underlying file as soon as the pixels have been converted.
        with Image.open(image) as opened:
            image_rgb = opened.convert("RGB")
    else:
        # Assumes array-like pixel data (e.g. the NumPy array gr.Image yields
        # by default) -- TODO confirm dtype/shape if other sources are added.
        image_rgb = Image.fromarray(image).convert("RGB")

    print(image_rgb, instruction)

    # The image travels inside the message content, so chat() gets image=None.
    inputs = [{"role": "user", "content": [image_rgb, instruction]}]

    parameters = {
        "sampling": sampling,
        "temperature": temperature,
        "top_p": top_p,
        # UI sliders may deliver floats; these two are integer counts.
        "top_k": int(top_k),
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": int(max_tokens),
    }

    # The module binds the loaded model to `repo`; the old reference to an
    # undefined `model` raised NameError on every call.
    output = repo.chat(image=None, msgs=inputs, tokenizer=tokenizer, **parameters)
    return output
def cloud() -> None:
    """Print a fixed "[CLOUD]" status line to stdout and return nothing."""
    status_line = "[CLOUD] | Space maintained."
    print(status_line)
# Initialize
with gr.Blocks(css=css) as main:
    # Header blurb.
    with gr.Column():
        gr.Markdown("🪄 Analyze images and caption them using state-of-the-art openbmb/MiniCPM-V-2_6.")

    # Inputs and generation controls.
    with gr.Column():
        # type="filepath" hands the handler a path on disk, which is what its
        # Image.open(...) call expects; the default would pass a NumPy array.
        # Named image_input so the builtin input() is not shadowed.
        image_input = gr.Image(type="filepath", label="Image")
        instruction = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Instruction")
        sampling = gr.Checkbox(value=False, label="Sampling")
        temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=0.7, label="Temperature")
        top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.8, label="Top P")
        top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=100, label="Top K")
        repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=1.05, label="Repetition Penalty")
        max_tokens = gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="Max Tokens")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")

    # Result area.
    with gr.Column():
        output = gr.Textbox(lines=1, value="", label="Output")

    # The input order must match generate()'s positional parameters.
    submit.click(
        fn=generate,
        inputs=[image_input, instruction, sampling, temperature, top_p, top_k, repetition_penalty, max_tokens],
        outputs=[output],
        queue=False,
    )
    # The cloud button only triggers a log line; it has no inputs or outputs.
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)