hugo flores garcia committed on
Commit 05d43c6
1 Parent(s): ee4b45b

app/interface fixes

Files changed (3):
  1. app.py +18 -47
  2. token_telephone/vamp_helper.py +1 -1
  3. vampnet/interface.py +6 -3
app.py CHANGED
@@ -19,38 +19,11 @@ from vampnet import mask as pmask
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 interface = Interface.default()
-
-# populate the model choices with any interface.yml files in the generated confs
-MODEL_CHOICES = {
-    "default": {
-        "Interface.coarse_ckpt": str(interface.coarse_path),
-        "Interface.coarse2fine_ckpt": str(interface.c2f_path),
-        "Interface.codec_ckpt": str(interface.codec_path),
-    }
-}
-generated_confs = Path("conf/generated")
-for conf_file in generated_confs.glob("*/interface.yml"):
-    with open(conf_file) as f:
-        _conf = yaml.safe_load(f)
-
-    # check if the coarse, c2f, and codec ckpts exist
-    # otherwise, dont' add this model choice
-    if not (
-        Path(_conf["Interface.coarse_ckpt"]).exists() and
-        Path(_conf["Interface.coarse2fine_ckpt"]).exists() and
-        Path(_conf["Interface.codec_ckpt"]).exists()
-    ):
-        continue
-
-    MODEL_CHOICES[conf_file.parent.name] = _conf
-
 
 def to_output(sig):
     return sig.sample_rate, sig.cpu().detach().numpy()[0][0]
 
-
-
-MAX_DURATION_S = 5
+MAX_DURATION_S = 10
 def load_audio(file):
     print(file)
     if isinstance(file, str):
@@ -91,6 +64,7 @@ def _vamp(
     typical_mass, typical_min_tokens, top_p,
     sample_cutoff, stretch_factor, api=False
 ):
+
     t0 = time.time()
     interface.to("cuda" if torch.cuda.is_available() else "cpu")
     print(f"using device {interface.device}")
@@ -105,15 +79,15 @@ def _vamp(
     sig = at.AudioSignal(input_audio, sr)
 
     # reload the model if necessary
-    interface.reload(
-        coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
-        c2f_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
-    )
+    interface.load_finetuned(model_choice)
 
     if pitch_shift_amt != 0:
         sig = shift_pitch(sig, pitch_shift_amt)
 
-    build_mask_kwargs = dict(
+    codes = interface.encode(sig)
+
+    mask = interface.build_mask(
+        codes, sig,
         rand_mask_intensity=1.0,
         prefix_s=0.0,
         suffix_s=0.0,
@@ -124,29 +98,26 @@ def _vamp(
         upper_codebook_mask=int(n_mask_codebooks),
     )
 
-    vamp_kwargs = dict(
-        temperature=sampletemp,
-        typical_filtering=typical_filtering,
-        typical_mass=typical_mass,
-        typical_min_tokens=typical_min_tokens,
-        top_p=None,
-        seed=_seed,
-        sample_cutoff=1.0,
-    )
 
     # save the mask as a txt file
     interface.set_chunk_size(10.0)
-    sig, mask, codes = interface.ez_vamp(
-        sig,
+    codes, mask = interface.vamp(
+        codes, mask,
         batch_size=1 if api else 1,
         feedback_steps=1,
         time_stretch_factor=stretch_factor,
-        build_mask_kwargs=build_mask_kwargs,
-        vamp_kwargs=vamp_kwargs,
         return_mask=True,
+        temperature=sampletemp,
+        typical_filtering=typical_filtering,
+        typical_mass=typical_mass,
+        typical_min_tokens=typical_min_tokens,
+        top_p=None,
+        seed=_seed,
+        sample_cutoff=1.0,
     )
     print(f"vamp took {time.time() - t0} seconds")
 
+    sig = interface.decode(codes)
 
     return to_output(sig)
 
@@ -352,7 +323,7 @@ with gr.Blocks() as demo:
 
     model_choice = gr.Dropdown(
         label="model choice",
-        choices=list(MODEL_CHOICES.keys()),
+        choices=list(interface.available_models()),
         value="default",
         visible=True
     )
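
Taken together, these app.py changes swap the removed `ez_vamp` wrapper for the Interface's explicit pipeline: encode the signal into codec tokens, build a mask over them, resample the masked tokens, then decode back to audio. Below is a minimal sketch of that flow, assuming the `encode` / `build_mask` / `vamp` / `decode` signatures shown in this diff; the file paths and the pared-down keyword arguments are illustrative only (the app passes the full set of mask and sampling parameters):

```python
import torch
import audiotools as at

from vampnet.interface import Interface

interface = Interface.default()
interface.to("cuda" if torch.cuda.is_available() else "cpu")

# pick a checkpoint; "default" is always a valid name after this commit
interface.load_finetuned("default")

sig = at.AudioSignal("input.wav")  # hypothetical input file

# 1) tokenize the waveform into codec codes
codes = interface.encode(sig)

# 2) mark which tokens the model is allowed to rewrite
#    (other build_mask kwargs are assumed to have defaults)
mask = interface.build_mask(
    codes, sig,
    rand_mask_intensity=1.0,
    upper_codebook_mask=3,  # illustrative: mask only the coarsest codebooks
)

# 3) iteratively resample the masked tokens
codes, mask = interface.vamp(
    codes, mask,
    temperature=0.8,
    return_mask=True,
)

# 4) decode the edited tokens back to a waveform
out = interface.decode(codes)
out.write("output.wav")
```

One practical upside of splitting the stages this way is that callers can reuse the encoded `codes` across several vamp passes, or inspect the mask, instead of re-encoding the signal inside a single opaque call.
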
token_telephone/vamp_helper.py CHANGED
@@ -136,7 +136,7 @@ def ez_variation(
 
     # save the mask as a txt file
     interface.set_chunk_size(10.0)
-    sig, mask, codes = interface.ez_vamp(
+    sig, mask, codes = interface.vamp(
         sig,
         batch_size=1,
         feedback_steps=1,
vampnet/interface.py CHANGED
@@ -128,13 +128,16 @@ class Interface(torch.nn.Module):
     @classmethod
     def available_models(cls):
         from . import list_finetuned
-        return list_finetuned()
+        return list_finetuned() + ["default"]
 
 
     def load_finetuned(self, name: str):
         assert name in self.available_models(), f"{name} is not a valid model name"
-        from . import download_finetuned
-        coarse_path, c2f_path = download_finetuned(name)
+        from . import download_finetuned, download_default
+        if name == "default":
+            coarse_path, c2f_path = download_default()
+        else:
+            coarse_path, c2f_path = download_finetuned(name)
         self.reload(
             coarse_ckpt=coarse_path,
             c2f_ckpt=c2f_path,
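
On the interface.py side, "default" becomes a first-class model name, which is what lets the Gradio dropdown above populate itself from `available_models()` instead of the deleted `MODEL_CHOICES` registry. A short usage sketch, assuming `download_default` mirrors `download_finetuned` in returning a `(coarse_path, c2f_path)` pair, as the diff implies:

```python
from vampnet.interface import Interface

interface = Interface.default()

# every selectable checkpoint; "default" is appended to the finetuned names
print(interface.available_models())

# downloads the checkpoints if needed, then hot-swaps them via
# Interface.reload(coarse_ckpt=..., c2f_ckpt=...)
interface.load_finetuned("default")
```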