Keyven's picture
whisper integration
2449b43
raw
history blame
7.51 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import re
import copy
import secrets
from pathlib import Path
import os
# Whisper is installed at start-up because it is pulled straight from GitHub.
# Install only when the import fails, so a restart with the package already
# present does not re-run pip.
try:
    import whisper
except ImportError:
    import importlib
    import subprocess
    import sys

    # Use the running interpreter's own pip (a bare "pip" on PATH may belong
    # to a different Python) and raise if the installation fails instead of
    # continuing into a confusing ImportError.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/openai/whisper.git",
        ]
    )
    # Make the import system notice the freshly installed package.
    importlib.invalidate_caches()
    import whisper

# Speech-to-text model used by transcribe_audio; loaded once at import time.
model_whisper = whisper.load_model("small")
# Constants
# Matches a "<box>...</box>" span (non-greedy, may cross newlines).
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
# ASCII punctuation characters checked at the end of a typed query.
PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

# Initialize model and tokenizer (both come from the same checkpoint).
_QWEN_CHECKPOINT = "Qwen/Qwen-VL-Chat-Int4"
tokenizer = AutoTokenizer.from_pretrained(_QWEN_CHECKPOINT, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    _QWEN_CHECKPOINT,
    device_map="auto",
    trust_remote_code=True,
)
model = model.eval()
# Single-pass escape table applied to lines inside a fenced code block.
# "&" is included so that code containing entity-like text (e.g. "&lt;") is
# shown literally instead of being decoded by the browser.
_CODE_LINE_ESCAPES = str.maketrans(
    {
        "&": "&amp;",
        "`": r"\`",
        "<": "&lt;",
        ">": "&gt;",
        " ": "&nbsp;",
        "*": "&ast;",
        "_": "&lowbar;",
        "-": "&#45;",
        ".": "&#46;",
        "!": "&#33;",
        "(": "&#40;",
        ")": "&#41;",
        "$": "&#36;",
    }
)


def format_text(text: str) -> str:
    """Format text for rendering in the chat UI.

    Empty lines are dropped. A line containing a triple backtick toggles a
    fenced code block: the opening fence becomes ``<pre><code ...>`` (the text
    after the backticks is used as the language class) and the closing fence
    becomes ``<br></code></pre>``. Every other line except the first is
    prefixed with ``<br>``; lines inside a code block are additionally
    escaped so they are displayed verbatim.

    Args:
        text: Raw text, possibly containing newlines and code fences.

    Returns:
        The formatted text as a single string without newlines.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0  # odd while inside a code block
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_LINE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)
def transcribe_audio(audio):
    """Transcribe a recorded audio file to text with the Whisper model.

    Args:
        audio: Path of the audio file to transcribe. May be ``None`` or empty
            when there is no recording (e.g. the recording was cleared).

    Returns:
        The decoded text, or an empty string when there is no audio.
    """
    # Nothing recorded: loading a missing path would raise, so return early.
    if not audio:
        return ""
    samples = whisper.pad_or_trim(whisper.load_audio(audio))
    mel = whisper.log_mel_spectrogram(samples).to(model_whisper.device)
    # No language is passed, so decoding is left to pick it; the previous
    # separate detect_language call only computed a result that was discarded.
    options = whisper.DecodingOptions(fp16=False)
    result = whisper.decode(model_whisper, mel, options)
    return result.text
def get_chat_response(chatbot, task_history):
    """Query the model for the latest turn and write the answer into the chat.

    Args:
        chatbot: List of (user, bot) display pairs; the last entry is the
            pending turn and is updated in place.
        task_history: List of (query, response) pairs; a query is either text
            or a tuple/list whose first element is an image path. The last
            entry's response is filled in place.

    Returns:
        The updated ``chatbot`` list.
    """
    displayed_query = chatbot[-1][0]
    raw_query = task_history[-1][0]

    # Fold each run of image entries into the text query that follows it, so
    # every model turn reads "Picture N: <img>path</img>\n...<text>".
    turns = []
    picture_no = 1
    pending = ""
    for question, answer in copy.deepcopy(task_history):
        if not isinstance(question, (tuple, list)):
            turns.append((pending + question, answer))
            pending = ""
            continue
        pending += f'Picture {picture_no}: <img>{question[0]}</img>' + '\n'
        picture_no += 1

    history, message = turns[:-1], turns[-1][0]
    response, history = model.chat(tokenizer, message, history=history)

    annotated = tokenizer.draw_bbox_on_latest_picture(response, history)
    if annotated is None:
        chatbot[-1] = (format_text(displayed_query), response)
    else:
        # Save the annotated picture under a fresh random directory in /tmp.
        out_dir = Path("/tmp") / secrets.token_hex(20)
        out_dir.mkdir(exist_ok=True, parents=True)
        out_path = out_dir / f"tmp{secrets.token_hex(5)}.jpg"
        annotated.save(str(out_path))
        # format_text is used here for the displayed query.
        chatbot[-1] = (format_text(displayed_query), (str(out_path),))
        # Show whatever text remains once the ref/box markup is removed.
        caption = response.replace("<ref>", "").replace(r"</ref>", "")
        caption = re.sub(BOX_TAG_PATTERN, "", caption)
        if caption != "":
            chatbot.append((None, caption))

    task_history[-1] = (raw_query, format_text(response))
    return chatbot
def handle_text_input(history, task_history, text):
    """Append a typed user message to the chat display and the task history.

    A single trailing punctuation character (one not preceded by another
    punctuation character) is removed from the copy stored in the task
    history; the displayed copy keeps the full text, formatted for the UI.

    Returns:
        A tuple of (new chat display list, new task history list, "").
    """
    has_single_trailing_punct = (
        len(text) >= 2
        and text[-1] in PUNCTUATION
        and text[-2] not in PUNCTUATION
    )
    model_text = text[:-1] if has_single_trailing_punct else text
    return (
        history + [(format_text(text), None)],
        task_history + [(model_text, None)],
        "",
    )
def handle_file_upload(history, task_history, file):
    """Append an uploaded file to the chat display and the task history.

    The file is recorded as a one-element tuple holding ``file.name`` with no
    response yet. New lists are returned; the inputs are not modified.

    Returns:
        A tuple of (new chat display list, new task history list).
    """
    upload_entry = ((file.name,), None)
    new_history = history + [upload_entry]
    new_task_history = task_history + [upload_entry]
    return new_history, new_task_history
def clear_input():
    """Return a Gradio update that sets the target component's value to ""."""
    emptied = gr.update(value="")
    return emptied
def clear_history(task_history):
    """Empty the task history in place and return an empty chat display."""
    del task_history[:]
    return []
def handle_regeneration(chatbot, task_history):
    """Discard the last response and ask the model for a new one.

    Nothing is changed when there is no history or the last turn has no
    response yet. Otherwise the last response is reset to None in both lists
    (in place) and get_chat_response is called again.

    Returns:
        The chat display list, regenerated when applicable.
    """
    print("Regenerate clicked")
    print("Before:", task_history, chatbot)
    if not task_history or task_history[-1][1] is None:
        return chatbot

    task_history[-1] = (task_history[-1][0], None)

    # The last display entry may be a bot-only message (no user part); in
    # that case the user turn to reopen is the entry before it.
    turn = chatbot.pop(-1)
    if turn[0] is None:
        turn = chatbot.pop(-1)
    chatbot.append((turn[0], None))

    print("After:", task_history, chatbot)
    return get_chat_response(chatbot, task_history)
# Build the UI: microphone input, chat window, text box, action buttons, and
# the event wiring between them.
with gr.Blocks(theme='gradio/soft') as demo:
    # Microphone recording, handed to the transcription handler as a file path.
    audio = gr.Audio(
        label="Input Audio",
        show_label=False,
        source="microphone",
        type="filepath",
    )
    gr.Markdown("# Qwen-VL Multimodal-Vision-Insight")
    gr.Markdown(
        "## Developed by Keyvan Hardani (Keyvven on [Twitter](https://twitter.com/Keyvven))\n"
        "Special thanks to [@Artificialguybr](https://twitter.com/artificialguybr) for the inspiration from his code.\n"
        "### Qwen-VL: A Multimodal Large Vision Language Model by Alibaba Cloud\n"
    )
    chatbot = gr.Chatbot(label='Qwen-VL-Chat', elem_classes="control-height", height=520)
    query = gr.Textbox(lines=2, label='Input')
    # Per-session list of (query, response) pairs shared by the handlers.
    task_history = gr.State([])

    with gr.Row():
        # Column takes its relative width through `scale`; `width` is not a
        # Column parameter.
        with gr.Column(scale=4):
            upload_btn = gr.UploadButton("📁 Upload", file_types=["image"], elem_classes="control-width")
        with gr.Column(scale=2):
            submit_btn = gr.Button("🚀 Submit", elem_classes="control-width")
        with gr.Column(scale=2):
            regen_btn = gr.Button("🤔️ Regenerate", elem_classes="control-width")
        with gr.Column(scale=2):
            clear_btn = gr.Button("🧹 Clear History", elem_classes="control-width")

    gr.Markdown("### Key Features:\n- **Strong Performance**: Surpasses existing LVLMs on multiple English benchmarks including Zero-shot Captioning and VQA.\n- **Multi-lingual Support**: Supports English, Chinese, and multi-lingual conversation.\n- **High Resolution**: Utilizes 448*448 resolution for fine-grained recognition and understanding.")

    # handle_text_input returns (chat, task_history, ""); the third value is
    # routed to the text box so it is cleared as part of the same event.
    submit_btn.click(
        handle_text_input,
        [chatbot, task_history, query],
        [chatbot, task_history, query],
    ).then(
        get_chat_response, [chatbot, task_history], [chatbot], show_progress=True
    )
    clear_btn.click(clear_history, [task_history], [chatbot], show_progress=True)
    regen_btn.click(handle_regeneration, [chatbot, task_history], [chatbot], show_progress=True)
    upload_btn.upload(handle_file_upload, [chatbot, task_history, upload_btn], [chatbot, task_history], show_progress=True)
    # The component's value-change listener is `.change`; the transcription
    # is written into the text box.
    audio.change(transcribe_audio, inputs=[audio], outputs=[query])

demo.launch()