Shokoufeh commited on
Commit
fb3f0db
1 Parent(s): 02422d4

Update app.py to load model checkpoint

Browse files
Files changed (1) hide show
  1. app.py +31 -4
app.py CHANGED
@@ -1,7 +1,34 @@
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
  import gradio as gr
3
+ from model import SGMSE # Adjust according to your model's definition
4
 
5
+ # Load your model
6
+ model_path = "https://huggingface.co/sp-uhh/speech-enhancement-sgmse/resolve/main/pretrained_checkpoints/speech_enhancement/train_vb_29nqe0uh_epoch%3D115.ckpt"
7
+ model = SGMSE() # Initialize your model class
8
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
9
+ model.eval() # Set the model to evaluation mode
10
 
11
+ def enhance_audio(input_audio):
12
+ import torchaudio
13
+
14
+ # Load the input audio file
15
+ waveform, sample_rate = torchaudio.load(input_audio)
16
+
17
+ with torch.no_grad():
18
+ enhanced_waveform = model(waveform)
19
+
20
+ output_path = "enhanced_audio.wav"
21
+ torchaudio.save(output_path, enhanced_waveform.cpu(), sample_rate)
22
+ return output_path
23
+
24
+ # Create the Gradio interface
25
+ iface = gr.Interface(
26
+ fn=enhance_audio,
27
+ inputs=gr.Audio(source="upload", type="filepath"),
28
+ outputs=gr.Audio(type="file"),
29
+ title="Speech Enhancement Model",
30
+ description="Upload a noisy audio file to enhance it using the SGMSE model."
31
+ )
32
+
33
+ if __name__ == "__main__":
34
+ iface.launch()