Siddhant committed on
Commit b9d404b
1 Parent(s): 1cf9f4d

Update app.py

Files changed (1):
  app.py +12 -12
app.py CHANGED
@@ -22,7 +22,6 @@ import gradio as gr
 # import os
 
 # os.system('python -m unidic download')
-from transformers import pipeline
 import numpy as np
 from VAD.vad_iterator import VADIterator
 import torch
@@ -47,7 +46,7 @@ user_role = "user"
 tts_model = TTS(language="EN_NEWEST", device="auto")
 speaker_id = tts_model.hps.data.spk2id["EN-Newest"]
 blocksize = 512
-
+transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
 def int2float(sound):
     """
     Taken from https://github.com/snakers4/silero-vad
@@ -108,16 +107,17 @@ def transcribe(stream, new_chunk):
         array = torch.cat(vad_output).cpu().numpy()
         duration_ms = len(array) / sr * 1000
         if (not(duration_ms < min_speech_ms or duration_ms > max_speech_ms)):
-            input_features = ASR_processor(
-                array, sampling_rate=16000, return_tensors="pt"
-            ).input_features
-            print(input_features)
-            input_features = input_features.to("cpu", dtype=getattr(torch, "float16"))
-            pred_ids = ASR_model.generate(input_features, max_new_tokens=128, min_new_tokens=0, num_beams=1, return_timestamps=False,task="transcribe",language="en")
-            print(pred_ids)
-            prompt = ASR_processor.batch_decode(
-                pred_ids, skip_special_tokens=True, decode_with_timestamps=False
-            )[0]
+            # input_features = ASR_processor(
+            #     array, sampling_rate=16000, return_tensors="pt"
+            # ).input_features
+            # print(input_features)
+            # input_features = input_features.to("cpu", dtype=getattr(torch, "float16"))
+            # pred_ids = ASR_model.generate(input_features, max_new_tokens=128, min_new_tokens=0, num_beams=1, return_timestamps=False,task="transcribe",language="en")
+            # print(pred_ids)
+            # prompt = ASR_processor.batch_decode(
+            #     pred_ids, skip_special_tokens=True, decode_with_timestamps=False
+            # )[0]
+            prompt=transcriber({"sampling_rate": sr, "raw": array})["text"]
             print(prompt)
             # prompt=ASR_model.transcribe(array)["text"].strip()
             chat.append({"role": user_role, "content": prompt})
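
For context, a minimal self-contained sketch of the pipeline-based transcription path this commit switches to. The helper name transcribe_chunk and the standalone demo at the bottom are illustrative only, not part of app.py; the explicit `from transformers import pipeline` import is included here because the added transcriber call depends on it, and the 16 kHz float32 input matches the sampling rate used in the replaced ASR_processor code.

import numpy as np
from transformers import pipeline  # needed for the transcriber = pipeline(...) call

# Build the Whisper ASR pipeline once at startup, as in the diff above.
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")

def transcribe_chunk(array: np.ndarray, sr: int = 16000) -> str:
    """Transcribe a mono float32 waveform; the pipeline accepts a dict of raw audio plus its sampling rate."""
    return transcriber({"sampling_rate": sr, "raw": array})["text"]

if __name__ == "__main__":
    # One second of silence stands in for real VAD output from the audio stream.
    print(transcribe_chunk(np.zeros(16000, dtype=np.float32)))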