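# Hugging Face Hub file-view residue, kept as a comment so the file parses:
# author: yellowcandle
# commit: d3b8a9b (unverified) - "try to fix runtime error on HF"
# page links: raw | history | blame | contribute | delete
# scan status: No virus; file size: 4 kB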
import spaces
import gradio as gr
import os
import orjson
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, AutoModelForCausalLM, AutoTokenizer
# Module-level model handles. load_models() fills them in; the transcribe and
# proofread handlers treat None as "model not loaded yet".
transcribe_model, proofread_model = None, None
@spaces.GPU(duration=60)
def transcribe_audio(audio):
    """Transcribe an uploaded audio file with the loaded speech-to-text model.

    Args:
        audio: Path of the uploaded audio file (the gr.Audio component is
            created with type="filepath"), or None when nothing was uploaded.

    Returns:
        The transcribed text, or a human-readable message when no audio was
        given or the transcription model has not been loaded yet.
    """
    if audio is None:
        return "Please upload an audio file."
    if transcribe_model is None:
        return "Please load the transcription model first."
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    # transcribe_model is an instantiated model, not a repo id. Resolve the
    # matching processor from the id/path the model was loaded from; passing
    # the model object itself to from_pretrained fails.
    processor = AutoProcessor.from_pretrained(transcribe_model.name_or_path)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=transcribe_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=25,  # long recordings are split into 25 s chunks
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )
    result = pipe(audio)
    return result["text"]
@spaces.GPU(duration=120)
def proofread(text):
    """Rewrite transcribed text in Traditional Chinese and append its key points.

    Args:
        text: The transcribed text to proofread (value of the
            "Transcribed Text" textbox).

    Returns:
        Only the newly generated text (the prompt is not echoed back), or a
        human-readable message when no text was given or the proofreading
        model has not been loaded yet.
    """
    if not text or not text.strip():
        return "Please provide the transcribed text for proofreading."
    if proofread_model is None:
        return "Please load the proofreading model first."
    # A causal-LM model object has no .tokenizer attribute; load the
    # tokenizer belonging to the checkpoint the model was loaded from.
    tokenizer = AutoTokenizer.from_pretrained(proofread_model.name_or_path)
    messages = [
        # System instruction (Chinese): "Tidy this text into Traditional
        # Chinese vernacular prose, and add the key points of the whole text
        # at the end."
        {"role": "system", "content": "用繁體中文語體文整理這段文字,在最後加上整段文字的重點。"},
        {"role": "user", "content": text},
    ]
    # Render the conversation with the model's own chat template and return
    # PyTorch tensors (input_ids + attention_mask) on the model's device.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(proofread_model.device)
    # Some chat checkpoints define no pad token; fall back to EOS so
    # generate() does not have to guess one.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    outputs = proofread_model.generate(
        **inputs,
        # Without an explicit budget generate() stops at the default max_length.
        max_new_tokens=1024,
        pad_token_id=pad_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
@spaces.GPU(duration=120)
def load_models(transcribe_model_id, proofread_model_id):
    """Load both models into the module-level globals used by the handlers.

    Args:
        transcribe_model_id: Hub id of the speech-to-text checkpoint
            (loaded with AutoModelForSpeechSeq2Seq).
        proofread_model_id: Hub id of the proofreading checkpoint
            (loaded with AutoModelForCausalLM).

    Returns:
        None. The loaded models are stored in the globals
        transcribe_model and proofread_model.
    """
    # NOTE(review): on ZeroGPU Spaces, @spaces.GPU functions may run in a
    # separate worker, so globals assigned here might not be visible to the
    # other handlers - confirm on the target hardware.
    global transcribe_model, proofread_model
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    transcribe_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        transcribe_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    transcribe_model.to(device)
    # Load the LLM in the same reduced precision as the ASR model. The default
    # float32 load of the 8B checkpoint offered in the UI needs roughly 32 GB.
    proofread_model = AutoModelForCausalLM.from_pretrained(
        proofread_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
    )
    proofread_model.to(device)
# Gradio UI: choose and load models, upload audio, transcribe, then proofread.
with gr.Blocks() as demo:
    gr.Markdown("""
# Audio Transcription and Proofreading
1. Select models for transcription and proofreading and load them
2. Upload an audio file (Wait for the file to be fully loaded first)
3. Transcribe the audio
4. Proofread the transcribed text
""")
    with gr.Row():
        transcribe_model_dropdown = gr.Dropdown(choices=["openai/whisper-large-v2", "alvanlii/whisper-small-cantonese"], value="alvanlii/whisper-small-cantonese", label="Select Transcription Model")
        proofread_model_dropdown = gr.Dropdown(choices=["hfl/llama-3-chinese-8b-instruct-v3"], value="hfl/llama-3-chinese-8b-instruct-v3", label="Select Proofreading Model")
    load_button = gr.Button("Load Models")
    audio = gr.Audio(sources="upload", type="filepath")
    transcribe_button = gr.Button("Transcribe")
    transcribed_text = gr.Textbox(label="Transcribed Text")
    proofread_button = gr.Button("Proofread")
    proofread_output = gr.Textbox(label="Proofread Text")
    load_button.click(load_models, inputs=[transcribe_model_dropdown, proofread_model_dropdown])
    transcribe_button.click(transcribe_audio, inputs=audio, outputs=transcribed_text)
    # Proofreading is triggered only by the explicit button (step 4 above).
    # A transcribed_text.change listener would launch a 120 s GPU generation on
    # every value change, including each keystroke and error messages.
    proofread_button.click(proofread, inputs=transcribed_text, outputs=proofread_output)
demo.launch()