Spaces: Runtime error
anyantudre committed
Commit • 5c30e04
1 Parent(s): 1ffb58d

Rename speech_to_text.py to goai_stt.py

Files changed:
- goai_stt.py         +49 -0
- speech_to_text.py   +0 -46
goai_stt.py
ADDED
@@ -0,0 +1,49 @@
+import librosa
+import torch
+import time
+from transformers import set_seed, Wav2Vec2ForCTC, AutoProcessor
+
+
+def goai_stt(fichier, device):
+    """
+    Transcrire un fichier audio donné.
+
+    Paramètres
+    ----------
+    fichier: str
+        Le chemin d'accès au fichier audio.
+
+    device: str
+        GPU ou CPU
+
+    Return
+    ----------
+    transcript: str
+        Le texte transcrit.
+    """
+
+
+    ### assurer reproducibilité
+    set_seed(2024)
+
+    start_time = time.time()
+
+    ### charger le modèle de transcription
+    model_id = "anyantudre/wav2vec2-large-mms-1b-mos-V1"
+
+    processor = AutoProcessor.from_pretrained(model_id)
+    model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="mos", device=device, ignore_mismatched_sizes=True)
+
+    ### preprocessing de l'audio
+    signal, sampling_rate = librosa.load(fichier, sr=16000)
+    inputs = processor(signal, sampling_rate=16_000, return_tensors="pt", padding=True)
+
+    ### faire l'inference
+    with torch.no_grad():
+        outputs = model(**inputs).logits
+
+    pred_ids = torch.argmax(outputs, dim=-1)[0]
+    transcription = processor.decode(pred_ids)
+
+    print("Temps écoulé: ", int(time.time() - start_time), " secondes")
+    return transcription
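A note on the new function's device argument: as committed, it is passed to Wav2Vec2ForCTC.from_pretrained rather than used to move the model or the processed inputs, which is not the usual placement pattern in transformers. Below is a minimal, hedged sketch of the more conventional .to(device) approach around the same model calls; the audio path is a placeholder and the snippet is illustrative, not part of the commit.

# Illustrative sketch (not from the repo): same model calls with explicit device placement.
import librosa
import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC

model_id = "anyantudre/wav2vec2-large-mms-1b-mos-V1"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(
    model_id, target_lang="mos", ignore_mismatched_sizes=True
).to(device)  # move the weights explicitly instead of passing device= to from_pretrained

signal, _ = librosa.load("audio/sample_moore.wav", sr=16000)  # placeholder path, not a file from the repo
inputs = processor(signal, sampling_rate=16_000, return_tensors="pt", padding=True)
inputs = {k: v.to(device) for k, v in inputs.items()}  # move the input tensors as well

with torch.no_grad():
    logits = model(**inputs).logits
pred_ids = torch.argmax(logits, dim=-1)[0]
print(processor.decode(pred_ids))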
speech_to_text.py
DELETED
@@ -1,46 +0,0 @@
-import librosa
-import torch
-from transformers import Wav2Vec2ForCTC, AutoProcessor
-from transformers import set_seed
-import time
-
-
-def transcribe(fp:str, target_lang:str) -> str:
-    '''
-    For given audio file, transcribe it.
-
-    Parameters
-    ----------
-    fp: str
-        The file path to the audio file.
-    target_lang:str
-        The ISO-3 code of the target language.
-
-    Returns
-    ----------
-    transcript:str
-        The transcribed text.
-    '''
-    # Ensure replicability
-    set_seed(555)
-    start_time = time.time()
-
-    # Load transcription model
-    model_id = "facebook/mms-1b-all"
-
-    processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
-    model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang=target_lang, ignore_mismatched_sizes=True)
-
-    # Process the audio
-    signal, sampling_rate = librosa.load(fp, sr=16000)
-    inputs = processor(signal, sampling_rate=16_000, return_tensors="pt")
-
-    # Inference
-    with torch.no_grad():
-        outputs = model(**inputs).logits
-
-    ids = torch.argmax(outputs, dim=-1)[0]
-    transcript = processor.decode(ids)
-
-    print("Time elapsed: ", int(time.time() - start_time), " seconds")
-    return transcript
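For comparison, the removed function exposed target_lang as a parameter on the multilingual facebook/mms-1b-all checkpoint instead of hard-coding Mooré. A minimal usage sketch, assuming the old module were still importable; the audio path is a placeholder.

# Illustrative usage of the removed transcribe() (not part of the current repo state).
from speech_to_text import transcribe

# "mos" is the ISO 639-3 code for Mooré; the path below is a placeholder.
transcript = transcribe("audio/sample.wav", target_lang="mos")
print(transcript)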