sindhuhegde committed on
Commit
f0d8178
1 Parent(s): 4b11292

Update app

Browse files
Files changed (2) hide show
  1. app.py +2 -0
  2. utils/audio_utils.py +74 -72
app.py CHANGED
@@ -33,6 +33,7 @@ use_cuda = torch.cuda.is_available()
33
  batch_size = 12
34
  fps = 25
35
  n_negative_samples = 100
 
36
 
37
  # Initialize the mediapipe holistic keypoint detection model
38
  holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
@@ -420,6 +421,7 @@ def load_rgb_masked_frames(input_frames, kp_dict, asd=False, stride=1, window_fr
420
 
421
  input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0,input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])])
422
  # print("Input images window: ", input_frames.shape) # Tx25x270x480x3
 
423
 
424
  num_frames = input_frames.shape[0]
425
 
 
33
  batch_size = 12
34
  fps = 25
35
  n_negative_samples = 100
36
+ print("Device: ", device)
37
 
38
  # Initialize the mediapipe holistic keypoint detection model
39
  holistic = mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5)
 
421
 
422
  input_frames = np.array([input_frames[i:i+window_frames, :, :] for i in range(0,input_frames.shape[0], stride) if (i+window_frames <= input_frames.shape[0])])
423
  # print("Input images window: ", input_frames.shape) # Tx25x270x480x3
424
+ print("Successfully created masked input frames")
425
 
426
  num_frames = input_frames.shape[0]
427
 
utils/audio_utils.py CHANGED
@@ -9,97 +9,99 @@ warnings.filterwarnings("ignore", category=FutureWarning)
9
 
10
 
11
  audio_opts = {
12
- 'sample_rate': 16000,
13
- 'n_fft': 512,
14
- 'win_length': 320,
15
- 'hop_length': 160,
16
- 'n_mel': 80,
17
  }
18
 
19
 
20
  def load_wav(path, fr=0, to=10000, sample_rate=16000):
21
- """Loads Audio wav from path at time indices given by fr, to (seconds)"""
22
 
23
- _, wav = wavfile.read(path)
24
- fr_aud = int(np.round(fr * sample_rate))
25
- to_aud = int(np.round((to) * sample_rate))
26
 
27
- wav = wav[fr_aud:to_aud]
28
 
29
- return wav
30
 
31
 
32
  def wav2filterbanks(wav, mel_basis=None):
33
- """
34
- :param wav: Tensor b x T
35
- """
36
-
37
- assert len(wav.shape) == 2, 'Need batch of wavs as input'
38
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
- # device = 'cpu'
40
- spect = torch.stft(wav,
41
- n_fft=audio_opts['n_fft'],
42
- hop_length=audio_opts['hop_length'],
43
- win_length=audio_opts['win_length'],
44
- window=torch.hann_window(audio_opts['win_length']).to(device),
45
- center=True,
46
- pad_mode='reflect',
47
- normalized=False,
48
- onesided=True) # b x F x T x 2
49
- spect = spect[:, :, :-1, :]
50
-
51
- # ----- Log filterbanks --------------
52
- # mag spectrogram - # b x F x T
53
- mag = power_spect = torch.norm(spect, dim=-1)
54
- phase = torch.atan2(spect[..., 1], spect[..., 0])
55
- if mel_basis is None:
56
- # Build a Mel filter
57
- mel_basis = torch.from_numpy(
58
- librosa.filters.mel(audio_opts['sample_rate'],
59
- audio_opts['n_fft'],
60
- n_mels=audio_opts['n_mel'],
61
- fmin=0,
62
- fmax=int(audio_opts['sample_rate'] / 2)))
63
- mel_basis = mel_basis.float().to(power_spect.device)
64
- features = torch.log(torch.matmul(mel_basis, power_spect) +
65
- 1e-20) # b x F x T
66
- features = features.permute([0, 2, 1]).contiguous() # b x T x F
67
- # -------------------
68
-
69
- # norm_axis = 1 # normalize every sample over time
70
- # mean = features.mean(dim=norm_axis, keepdim=True) # b x 1 x F
71
- # std_dev = features.std(dim=norm_axis, keepdim=True) # b x 1 x F
72
- # features = (features - mean) / std_dev # b x T x F
73
-
74
- return features, mag, phase, mel_basis
 
 
75
 
76
 
77
  def torch_mag_phase_2_np_complex(mag_spect, phase):
78
- complex_spect_2d = torch.stack(
79
- [mag_spect * torch.cos(phase), mag_spect * torch.sin(phase)], -1)
80
- complex_spect_np = complex_spect_2d.cpu().detach().numpy()
81
- complex_spect_np = complex_spect_np[..., 0] + 1j * complex_spect_np[..., 1]
82
- return complex_spect_np
83
 
84
 
85
  def torch_mag_phase_2_complex_as_2d(mag_spect, phase):
86
- complex_spect_2d = torch.stack(
87
- [mag_spect * torch.cos(phase), mag_spect * torch.sin(phase)], -1)
88
- return complex_spect_2d
89
 
90
 
91
  def torch_phase_from_normalized_complex(spect):
92
- phase = torch.atan2(spect[..., 1], spect[..., 0])
93
- return phase
94
 
95
 
96
  def reconstruct_wav_from_mag_phase(mag, phase):
97
- spect = torch_mag_phase_2_np_complex(mag, phase)
98
- wav = np.stack([
99
- librosa.core.istft(spect[ii],
100
- hop_length=audio_opts['hop_length'],
101
- win_length=audio_opts['win_length'],
102
- center=True) for ii in range(spect.shape[0])
103
- ])
104
-
105
- return wav
 
9
 
10
 
# Shared STFT / mel-filterbank settings read by the audio helpers below.
audio_opts = dict(
    sample_rate=16000,  # passed to the mel filter builder as the sampling rate
    n_fft=512,          # FFT size for the STFT and the mel filter
    win_length=320,     # analysis window length used for STFT / inverse STFT
    hop_length=160,     # hop between successive frames
    n_mel=80,           # number of mel bands in the filterbank
)
18
 
19
 
def load_wav(path, fr=0, to=10000, sample_rate=16000):
    """Read a wav file and return the samples between ``fr`` and ``to`` seconds.

    :param path: location of the wav file, handed to ``wavfile.read``
    :param fr: start time in seconds
    :param to: end time in seconds; values past the end of the file simply
        return everything up to the last sample (plain slice semantics)
    :param sample_rate: rate used to turn seconds into sample indices. The
        first value returned by ``wavfile.read`` is discarded, so this
        argument, not the file, decides the conversion.
    :return: the sliced sample array
    """
    samples = wavfile.read(path)[1]
    start, stop = (int(np.round(seconds * sample_rate)) for seconds in (fr, to))
    return samples[start:stop]
30
 
31
 
def wav2filterbanks(wav, mel_basis=None):
    """Compute log mel-filterbank features plus the STFT magnitude and phase.

    :param wav: Tensor b x T (a batch of waveforms; must be 2-D)
    :param mel_basis: optional mel filter matrix (n_mel x F). When omitted it
        is built with librosa from the values in ``audio_opts``.
    :return: tuple ``(features, mag, phase, mel_basis)`` where ``features`` is
        b x T' x n_mel (log of the mel-projected magnitude), ``mag`` and
        ``phase`` are b x F x T', and ``mel_basis`` is the filter matrix that
        was used. T' is the number of STFT frames minus the final one.
    """
    assert len(wav.shape) == 2, 'Need batch of wavs as input'

    # The window must live on the same device as the input. Picking the device
    # from torch.cuda.is_available() breaks CPU inputs on CUDA machines.
    window = torch.hann_window(audio_opts['win_length'], device=wav.device)

    spect = torch.stft(wav,
                       return_complex=True,
                       n_fft=audio_opts['n_fft'],
                       hop_length=audio_opts['hop_length'],
                       win_length=audio_opts['win_length'],
                       window=window,
                       center=True,
                       pad_mode='reflect',
                       normalized=False,
                       onesided=True)  # complex, b x F x T
    spect = torch.view_as_real(spect)  # b x F x T x 2 (real, imag)
    spect = spect[:, :, :-1, :]  # drop the final frame

    # ----- Log filterbanks --------------
    # Magnitude spectrogram (L2 norm over the real/imag axis): b x F x T
    mag = torch.norm(spect, dim=-1)
    phase = torch.atan2(spect[..., 1], spect[..., 0])

    if mel_basis is None:
        # Build a Mel filter. sr / n_fft are passed by keyword because recent
        # librosa releases reject them as positional arguments.
        mel_basis = torch.from_numpy(
            librosa.filters.mel(sr=audio_opts['sample_rate'],
                                n_fft=audio_opts['n_fft'],
                                n_mels=audio_opts['n_mel'],
                                fmin=0,
                                fmax=int(audio_opts['sample_rate'] / 2)))
    # No-op for a basis that is already float32 on the right device; otherwise
    # makes a caller-supplied basis compatible with the matmul below.
    mel_basis = mel_basis.float().to(mag.device)

    # The small epsilon keeps log() finite on silent frames.
    features = torch.log(torch.matmul(mel_basis, mag) + 1e-20)  # b x n_mel x T
    features = features.permute([0, 2, 1]).contiguous()  # b x T x n_mel
    # Features are returned without any per-sample normalization.

    return features, mag, phase, mel_basis
77
 
78
 
def torch_mag_phase_2_np_complex(mag_spect, phase):
    """Convert torch magnitude/phase tensors into a complex NumPy array.

    :param mag_spect: magnitude tensor
    :param phase: phase tensor in radians, same shape as ``mag_spect``
    :return: NumPy array of ``mag * cos(phase) + 1j * mag * sin(phase)``,
        detached from the graph and moved to the CPU
    """
    real_part = (mag_spect * torch.cos(phase)).cpu().detach().numpy()
    imag_part = (mag_spect * torch.sin(phase)).cpu().detach().numpy()
    return real_part + 1j * imag_part
85
 
86
 
def torch_mag_phase_2_complex_as_2d(mag_spect, phase):
    """Pack magnitude/phase into a real tensor with a trailing (real, imag) axis.

    :param mag_spect: magnitude tensor
    :param phase: phase tensor in radians, same shape as ``mag_spect``
    :return: tensor with one extra last dimension of size 2 holding
        ``mag * cos(phase)`` at index 0 and ``mag * sin(phase)`` at index 1
    """
    real_part = mag_spect * torch.cos(phase)
    imag_part = mag_spect * torch.sin(phase)
    return torch.stack((real_part, imag_part), dim=-1)
91
 
92
 
def torch_phase_from_normalized_complex(spect):
    """Return the phase angle of a spectrogram stored as (real, imag) pairs.

    :param spect: tensor whose last axis holds the real part at index 0 and
        the imaginary part at index 1
    :return: tensor of ``atan2(imag, real)`` with the last axis removed
    """
    real_part = spect[..., 0]
    imag_part = spect[..., 1]
    return torch.atan2(imag_part, real_part)
96
 
97
 
def reconstruct_wav_from_mag_phase(mag, phase):
    """Rebuild a batch of waveforms from magnitude and phase spectrograms.

    :param mag: torch magnitude tensor, indexed by batch on its first axis
    :param phase: torch phase tensor matching ``mag``
    :return: NumPy array stacking one inverse-STFT waveform per batch item
    """
    complex_spect = torch_mag_phase_2_np_complex(mag, phase)

    # Same hop / window settings as the forward transform in audio_opts.
    istft_kwargs = dict(hop_length=audio_opts['hop_length'],
                        win_length=audio_opts['win_length'],
                        center=True)

    waveforms = []
    for item_spect in complex_spect:
        waveforms.append(librosa.core.istft(item_spect, **istft_kwargs))

    return np.stack(waveforms)