import os
import tempfile
import time

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor

app = FastAPI()
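
# Run inference on the first GPU when one is available; device=-1 makes the
# transformers pipelines fall back to the CPU.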
device = 0 if torch.cuda.is_available() else -1
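
# Load the Whisper checkpoint along with its processor, which bundles the
# tokenizer and the log-mel feature extractor.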
asr_model_name = "openai/whisper-large-v2"
model = WhisperForConditionalGeneration.from_pretrained(asr_model_name)
processor = WhisperProcessor.from_pretrained(asr_model_name)
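
# Forcing the decoder prompt pins the task to transcription (not translation)
# and the language to Portuguese, so Whisper skips language auto-detection.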
forced_decoder_ids = processor.get_decoder_prompt_ids(language="portuguese", task="transcribe")
model.config.forced_decoder_ids = forced_decoder_ids

asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=device,
)
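
# Extractive question answering in Portuguese (BERT fine-tuned on a
# Portuguese version of SQuAD v1.1).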
qa_model_name = "pierreguillou/bert-base-cased-squad-v1.1-portuguese"
qa_pipeline = pipeline("question-answering", model=qa_model_name)
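
# Fixed Portuguese passage about the COVID-19 pandemic; /answer/ extracts
# answers to incoming questions from this text.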
context = r"""
A pandemia de COVID-19, também conhecida como pandemia de coronavírus, é uma pandemia em curso de COVID-19,
uma doença respiratória aguda causada pelo coronavírus da síndrome respiratória aguda grave 2 (SARS-CoV-2).
A doença foi identificada pela primeira vez em Wuhan, na província de Hubei, República Popular da China,
em 1 de dezembro de 2019, mas o primeiro caso foi reportado em 31 de dezembro do mesmo ano.
Acredita-se que o vírus tenha uma origem zoonótica, porque os primeiros casos confirmados
tinham principalmente ligações ao Mercado Atacadista de Frutos do Mar de Huanan, que também vendia animais vivos.
Em 11 de março de 2020, a Organização Mundial da Saúde declarou o surto uma pandemia. Até 8 de fevereiro de 2021,
pelo menos 105 743 102 casos da doença foram confirmados em pelo menos 191 países e territórios,
com cerca de 2 308 943 mortes e 58 851 440 pessoas curadas.
"""


class QuestionRequest(BaseModel):
    question: str


@app.post("/answer/")
async def answer_question(request: QuestionRequest):
    try:
        result = qa_pipeline(question=request.question, context=context)
        return {
            "question": request.question,
            "answer": result["answer"],
            "score": round(result["score"], 4),
            "start": result["start"],
            "end": result["end"],
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
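
# Example request (a sketch, assuming the app is served locally on port 7860;
# the question string is only an illustration):
#   curl -X POST http://localhost:7860/answer/ \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Quando a OMS declarou o surto uma pandemia?"}'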


@app.get("/")
def read_root():
    return {"message": "Welcome to the FastAPI app on Hugging Face Spaces!"}


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    start_time = time.time()

    # Persist the upload to disk: the ASR pipeline accepts a file path and
    # decodes it with ffmpeg, which detects the real container format.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
        temp_audio_file.write(await file.read())
        temp_file_path = temp_audio_file.name

    try:
        transcription_start = time.time()
        transcription = asr_pipeline(temp_file_path, return_timestamps=True)
        transcription_end = time.time()
    finally:
        # Clean up the temporary file even if transcription raises.
        os.remove(temp_file_path)

    end_time = time.time()
    print(f"Time to transcribe audio: {transcription_end - transcription_start:.4f} seconds")
    print(f"Total execution time: {end_time - start_time:.4f} seconds")

    return {"transcription": transcription["text"]}


@app.get("/playground/", response_class=HTMLResponse)
def playground():
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Voice Recorder</title>
    </head>
    <body>
        <h1>Record your voice</h1>
        <button id="startBtn">Start Recording</button>
        <button id="stopBtn" disabled>Stop Recording</button>
        <p id="status">Press start to record your voice...</p>

        <audio id="audioPlayback" controls style="display:none;"></audio>
        <script>
            let mediaRecorder;
            let audioChunks = [];

            const startBtn = document.getElementById('startBtn');
            const stopBtn = document.getElementById('stopBtn');
            const status = document.getElementById('status');
            const audioPlayback = document.getElementById('audioPlayback');

            // Start Recording
            startBtn.addEventListener('click', async () => {
                const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
                mediaRecorder = new MediaRecorder(stream);
                mediaRecorder.start();

                status.textContent = 'Recording...';
                startBtn.disabled = true;
                stopBtn.disabled = false;

                mediaRecorder.ondataavailable = event => {
                    audioChunks.push(event.data);
                };
            });

            // Stop Recording
            stopBtn.addEventListener('click', () => {
                // Register the handler before calling stop() so the final
                // data chunk is always flushed before it runs.
                mediaRecorder.onstop = async () => {
                    status.textContent = 'Recording stopped. Preparing to send...';
                    // Browsers record webm/ogg rather than wav; the server's
                    // ffmpeg decoder sniffs the actual format from the bytes.
                    const audioBlob = new Blob(audioChunks, { type: mediaRecorder.mimeType });
                    const audioUrl = URL.createObjectURL(audioBlob);
                    audioPlayback.src = audioUrl;
                    audioPlayback.style.display = 'block';
                    audioChunks = [];

                    // Send audio blob to FastAPI endpoint
                    const formData = new FormData();
                    formData.append('file', audioBlob, 'recording.wav');

                    const response = await fetch('/transcribe/', {
                        method: 'POST',
                        body: formData,
                    });

                    const result = await response.json();
                    status.textContent = 'Transcription: ' + result.transcription;
                };
                mediaRecorder.stop();

                startBtn.disabled = false;
                stopBtn.disabled = true;
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
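
# Note: navigator.mediaDevices.getUserMedia only works in a secure context,
# so the playground must be served over HTTPS (or from localhost).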


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)