import gradio as gr
import modelscope_studio as mgr
import librosa
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from argparse import ArgumentParser

DEFAULT_CKPT_PATH = 'Qwen/Qwen2-Audio-7B-Instruct'


def _get_args():
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--inbrowser", action="store_true", default=False,
                        help="Automatically launch the interface in a new tab on the default browser.")
    parser.add_argument("--server-port", type=int, default=8000,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Demo server name.")

    args = parser.parse_args()
    return args


def add_text(chatbot, task_history, input):
    """Append the user's text and any uploaded audio files to both histories."""
    text_content = input.text
    content = []
    if len(input.files) > 0:
        for i in input.files:
            content.append({'type': 'audio', 'audio_url': i.path})
    if text_content:
        content.append({'type': 'text', 'text': text_content})
    task_history.append({"role": "user", "content": content})

    chatbot.append([{
        "text": input.text,
        "files": input.files,
    }, None])
    # The trailing None clears the multimodal input box.
    return chatbot, task_history, None


def add_file(chatbot, task_history, audio_file):
    """Add an audio file to the chat history (helper; not wired to an event below)."""
    # Use the same {'type': 'audio', 'audio_url': ...} schema as add_text so
    # that predict() can parse the entry.
    task_history.append({"role": "user", "content": [{"type": "audio", "audio_url": audio_file.name}]})
    chatbot.append((f"[Audio file: {audio_file.name}]", None))
    return chatbot, task_history


def reset_user_input():
    """Reset the user input field."""
    return gr.update(value='')


def reset_state(task_history):
    """Reset the chat history."""
    return [], []


def regenerate(chatbot, task_history):
    """Regenerate the last bot response."""
    if task_history and task_history[-1]['role'] == 'assistant':
        task_history.pop()
        chatbot.pop()
    if task_history:
        chatbot, task_history = predict(chatbot, task_history)
    return chatbot, task_history


def predict(chatbot, task_history):
    """Generate a response from the model."""
    print(f"{task_history=}")
    print(f"{chatbot=}")
    text = processor.apply_chat_template(task_history, add_generation_prompt=True, tokenize=False)

    # Collect the waveform of every audio turn, resampled to the model's rate.
    audios = []
    for message in task_history:
        if isinstance(message["content"], list):
            for ele in message["content"]:
                if ele["type"] == "audio":
                    audios.append(
                        librosa.load(ele['audio_url'], sr=processor.feature_extractor.sampling_rate)[0]
                    )
    if len(audios) == 0:
        audios = None
    print(f"{text=}")
    print(f"{audios=}")

    inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
    # Move every tensor (input_ids, attention_mask, input_features, ...) to the
    # model's device; moving only input_ids would leave the audio features behind.
    inputs = inputs.to(model.device)

    # Response length is governed by model.generation_config.max_new_tokens, set
    # in __main__; passing max_length here would cap prompt + response together.
    generate_ids = model.generate(**inputs)
    generate_ids = generate_ids[:, inputs.input_ids.size(1):]

    response = processor.batch_decode(generate_ids, skip_special_tokens=True,
                                      clean_up_tokenization_spaces=False)[0]
    print(f"{response=}")
    task_history.append({'role': 'assistant', 'content': response})
    chatbot.append((None, response))  # Add the response to the chatbot pane.
    return chatbot, task_history


def _launch_demo(args):
    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>Qwen2-Audio-Instruct Bot</center>""")
        gr.Markdown("""\
<center><font size=3>This WebUI is based on Qwen2-Audio-Instruct, developed by Alibaba Cloud. \
(本WebUI基于Qwen2-Audio-Instruct打造,实现聊天机器人功能。)</center>""")
        gr.Markdown("""\
<center><font size=4>Qwen2-Audio \
<a href="https://modelscope.cn/models/qwen/Qwen2-Audio-7B">🤖</a> | \
<a href="https://huggingface.co/Qwen/Qwen2-Audio-7B">🤗</a> | \
Qwen2-Audio-Instruct \
<a href="https://modelscope.cn/models/qwen/Qwen2-Audio-7B-Instruct">🤖</a> | \
<a href="https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct">🤗</a> | \
<a href="https://github.com/QwenLM/Qwen2-Audio">Github</a></center>""")
""") chatbot = mgr.Chatbot(label='Qwen2-Audio-7B-Instruct', elem_classes="control-height", height=750) user_input = mgr.MultimodalInput( interactive=True, sources=['microphone', 'upload'], submit_button_props=dict(value="🚀 Submit (发送)"), upload_button_props=dict(value="📁 Upload (上传文件)", show_progress=True), ) task_history = gr.State([]) with gr.Row(): empty_bin = gr.Button("🧹 Clear History (清除历史)") regen_btn = gr.Button("🤔️ Regenerate (重试)") user_input.submit(fn=add_text, inputs=[chatbot, task_history, user_input], outputs=[chatbot, task_history, user_input]).then( predict, [chatbot, task_history], [chatbot, task_history], show_progress=True ) empty_bin.click(reset_state, outputs=[chatbot, task_history], show_progress=True) regen_btn.click(regenerate, [chatbot, task_history], [chatbot, task_history], show_progress=True) demo.queue().launch( share=True, inbrowser=args.inbrowser, server_port=args.server_port, server_name=args.server_name, ) if __name__ == "__main__": args = _get_args() if args.cpu_only: device_map = "cpu" else: device_map = "auto" model = Qwen2AudioForConditionalGeneration.from_pretrained( args.checkpoint_path, torch_dtype="auto", device_map=device_map, resume_download=True, ).eval() model.generation_config.max_new_tokens = 2048 # For chat. print("generation_config", model.generation_config) processor = AutoProcessor.from_pretrained(args.checkpoint_path, resume_download=True) _launch_demo(args)