anyantudre commited on
Commit
5c30e04
1 Parent(s): 1ffb58d

Rename speech_to_text.py to goai_stt.py

Browse files
Files changed (2) hide show
  1. goai_stt.py +49 -0
  2. speech_to_text.py +0 -46
goai_stt.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import torch
3
+ import time
4
+ from transformers import set_seed, Wav2Vec2ForCTC, AutoProcessor
5
+
6
+
7
def goai_stt(fichier, device):
    """
    Transcribe a given audio file to text with a Mooré ("mos") ASR model.

    Parameters
    ----------
    fichier : str
        Path to the audio file to transcribe.
    device : str
        Torch device to run inference on, e.g. "cuda" or "cpu".

    Returns
    -------
    transcription : str
        The transcribed text.
    """

    ### ensure reproducibility
    set_seed(2024)

    start_time = time.time()

    ### load the transcription model
    model_id = "anyantudre/wav2vec2-large-mms-1b-mos-V1"

    processor = AutoProcessor.from_pretrained(model_id)
    # BUGFIX: `from_pretrained` has no `device` keyword — the model must be
    # moved explicitly with `.to(device)` or it silently stays on CPU.
    model = Wav2Vec2ForCTC.from_pretrained(
        model_id, target_lang="mos", ignore_mismatched_sizes=True
    ).to(device)
    # Disable dropout etc. for deterministic inference.
    model.eval()

    ### audio preprocessing: resample to the 16 kHz rate the model expects
    signal, sampling_rate = librosa.load(fichier, sr=16000)
    inputs = processor(signal, sampling_rate=16_000, return_tensors="pt", padding=True)
    # BUGFIX: input tensors must live on the same device as the model,
    # otherwise the forward pass raises a device-mismatch error on GPU.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    ### run inference (no gradients needed)
    with torch.no_grad():
        outputs = model(**inputs).logits

    # Greedy CTC decoding: most likely token at every frame, then collapse.
    pred_ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(pred_ids)

    print("Temps écoulé: ", int(time.time() - start_time), " secondes")
    return transcription
speech_to_text.py DELETED
@@ -1,46 +0,0 @@
1
- import librosa
2
- import torch
3
- from transformers import Wav2Vec2ForCTC, AutoProcessor
4
- from transformers import set_seed
5
- import time
6
-
7
-
8
def transcribe(fp: str, target_lang: str) -> str:
    '''
    Transcribe an audio file with the multilingual MMS ASR model.

    Parameters
    ----------
    fp: str
        The file path to the audio file.
    target_lang:str
        The ISO-3 code of the target language.

    Returns
    ----------
    transcript:str
        The transcribed text.
    '''
    # Fix the random seed so runs are replicable
    set_seed(555)
    start_time = time.time()

    # Pull down the multilingual transcription model for the requested language
    model_id = "facebook/mms-1b-all"

    processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
    model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang=target_lang, ignore_mismatched_sizes=True)

    # Resample the audio to 16 kHz and turn it into model-ready tensors
    signal, sampling_rate = librosa.load(fp, sr=16000)
    inputs = processor(signal, sampling_rate=16_000, return_tensors="pt")

    # Forward pass without gradient tracking
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy CTC decode: pick the argmax token per frame, then collapse
    best_ids = torch.argmax(logits, dim=-1)[0]
    transcript = processor.decode(best_ids)

    print("Time elapsed: ", int(time.time() - start_time), " seconds")
    return transcript