osanseviero commited on
Commit
4fec958
1 Parent(s): 6979234

Add initial model

Browse files
Files changed (3) hide show
  1. .gitattributes +2 -0
  2. hubert_sd.ckpt +3 -0
  3. model.py +47 -0
.gitattributes CHANGED
@@ -15,3 +15,5 @@
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
18
+ *ckpt* filter=lfs diff=lfs merge=lfs -text
19
+
hubert_sd.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6ae7e9609674c3aa6104a2a037e85789f0b35e47d232e6754c2012b6bf8ca6d
3
+ size 31527386
model.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is just an example of what people would submit for inference.
3
+ """
4
+ import os
5
+ from typing import Dict
6
+
7
+ import torch
8
+ from s3prl.downstream.runner import Runner
9
+
10
+ class PreTrainedModel(Runner):
11
+ def __init__(self, path=""):
12
+ """
13
+ Initialize downstream model.
14
+ """
15
+ ckp_file = os.path.join(path, "hubert_sd.ckpt")
16
+ ckp = torch.load(ckp_file, map_location="cpu")
17
+ ckp["Args"].init_ckpt = ckp_file
18
+ ckp["Args"].mode = "inference"
19
+ ckp["Args"].device = "cpu" # Just to try in my computer
20
+ Runner.__init__(self, ckp["Args"], ckp["Config"])
21
+
22
+ def __call__(self, inputs) -> list[int]:
23
+ """
24
+ Args: inputs (:obj:`np.array`): The raw waveform of audio received. By
25
+ default at 16KHz.
26
+ Return: A list with logits.
27
+ """
28
+ for entry in self.all_entries:
29
+ entry.model.eval()
30
+
31
+ inputs = [torch.FloatTensor(inputs)]
32
+
33
+ with torch.no_grad():
34
+ features = self.upstream.model(inputs)
35
+ features = self.featurizer.model(inputs, features)
36
+ preds = self.downstream.model.inference(features, [])
37
+ return preds[0]
38
+
39
+ """
40
+ import io
41
+ import soundfile as sf
42
+ from urllib.request import urlopen
43
+ model = PreTrainedModel()
44
+ url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav"
45
+ data, samplerate = sf.read(io.BytesIO(urlopen(url).read()))
46
+ print(model(data))
47
+ """