gpt-omni committed
Commit 411819d
1 Parent(s): 9b186d7
Files changed (1)
  app.py +16 -6
app.py CHANGED
@@ -30,7 +30,7 @@ import soundfile as sf
 from litgpt.model import GPT, Config
 from lightning.fabric.utilities.load import _lazy_load as lazy_load
 from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
-from utils.snac_utils import get_snac, generate_audio_data
+from utils.snac_utils import get_snac
 import whisper
 from tqdm import tqdm
 from huggingface_hub import snapshot_download
@@ -80,19 +80,19 @@ if not os.path.exists(ckpt_dir):
 snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
 whispermodel = whisper.load_model("small").to(device)
 text_tokenizer = Tokenizer(ckpt_dir)
-fabric = L.Fabric(devices=1, strategy="auto")
+# fabric = L.Fabric(devices=1, strategy="auto")
 config = Config.from_file(ckpt_dir + "/model_config.yaml")
 config.post_adapter = False
 
 model = GPT(config, device=device)
 
-# model = fabric.setup(model)
 state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
 model.load_state_dict(state_dict, strict=True)
 model = model.to(device)
 model.eval()
 
 
+@spaces.GPU
 def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
     with torch.no_grad():
         mel = mel.unsqueeze(0).to(device)
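The `@spaces.GPU` decorators added in this hunk and the next come from the `spaces` package used on Hugging Face ZeroGPU hardware: a GPU is attached to the process only while a decorated function is executing, so each entry point that touches CUDA needs its own decorator. A minimal sketch of the pattern, not part of app.py itself:

    import spaces
    import torch

    @spaces.GPU  # on a ZeroGPU Space, a GPU is allocated only for the duration of this call
    def gpu_probe():
        # CUDA is available inside the decorated call
        return str(torch.zeros(1).to("cuda").device)

    print(gpu_probe())  # e.g. "cuda:0"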
@@ -128,6 +128,7 @@ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
     return torch.stack([audio_feature, audio_feature]), stacked_inputids
 
 
+@spaces.GPU
 def next_token_batch(
     model: GPT,
     audio_features: torch.tensor,
@@ -162,9 +163,19 @@ def load_audio(path):
     mel = whisper.log_mel_spectrogram(audio)
     return mel, int(duration_ms / 20) + 1
 
 
+
+@spaces.GPU
+def generate_audio_data(snac_tokens, snacmodel, device=None):
+    audio = reconstruct_tensors(snac_tokens, device)
+    with torch.inference_mode():
+        audio_hat = snacmodel.decode(audio)
+    audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
+    audio_data = audio_data.astype(np.int16)
+    audio_data = audio_data.tobytes()
+    return audio_data
+
 # @torch.inference_mode()
-@spaces.GPU
 def run_AT_batch_stream(
     audio_path,
     stream_stride=4,
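`generate_audio_data`, previously imported from utils.snac_utils and now defined locally (presumably so it can carry its own `@spaces.GPU` decorator), turns SNAC's float waveform into raw 16-bit PCM by scaling by 32768 and casting. Worth noting: a sample at exactly +1.0 scales to 32768, which wraps around to -32768 in int16, so conversions of this kind often clip first. A standalone sketch of the clipped variant; the helper name and test signal are illustrative:

    import numpy as np

    def float_to_pcm16(audio: np.ndarray) -> bytes:
        # clamp to [-1, 1] so +1.0 maps to 32767 instead of wrapping to -32768
        clipped = np.clip(audio, -1.0, 1.0)
        return (clipped * 32767.0).astype(np.int16).tobytes()

    tone = np.sin(np.linspace(0.0, 2.0 * np.pi, 480))  # one test cycle, 480 samples
    print(len(float_to_pcm16(tone)))  # 960 bytes: 480 samples x 2 bytes each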
@@ -178,11 +189,10 @@ def run_AT_batch_stream(
 
     assert os.path.exists(audio_path), f"audio file {audio_path} not found"
 
-    # with self.fabric.init_tensor():
     model.set_kv_cache(batch_size=2)
 
     mel, leng = load_audio(audio_path)
-    audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
+    audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
     T = input_ids[0].size(1)
     device = input_ids[0].device
 
 
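With the stray `self.` references removed, `run_AT_batch_stream` resolves `model`, `whispermodel`, and `device` from module scope, consistent with how app.py builds them at import time. Assuming the function is a generator that yields PCM chunks as they are decoded (which the streaming name and the byte-oriented `generate_audio_data` suggest), a hypothetical caller would look like:

    # hypothetical driver, assuming run_AT_batch_stream yields raw 16-bit PCM chunks
    with open("reply.pcm", "wb") as f:
        for chunk in run_AT_batch_stream("question.wav"):
            f.write(chunk)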