skytnt commited on
Commit
a3e7293
1 Parent(s): ed10990

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -16,6 +16,8 @@ import MIDI
16
  from midi_synthesizer import synthesis
17
  from midi_tokenizer import MIDITokenizer
18
 
 
 
19
  def softmax(x, axis):
20
  x_max = np.amax(x, axis=axis, keepdims=True)
21
  exp_x_shifted = np.exp(x - x_max)
@@ -58,7 +60,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
58
  input_tensor = prompt
59
  input_tensor = input_tensor[None, :, :]
60
  cur_len = input_tensor.shape[1]
61
- bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
62
  with bar:
63
  while cur_len < max_len:
64
  end = False
@@ -204,7 +206,7 @@ if __name__ == "__main__":
204
  parser = argparse.ArgumentParser()
205
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
206
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
207
- parser.add_argument("--max-gen", type=int, default=256, help="max")
208
  opt = parser.parse_args()
209
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
210
  model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")
 
16
  from midi_synthesizer import synthesis
17
  from midi_tokenizer import MIDITokenizer
18
 
19
+ in_space = os.getenv("SYSTEM") == "spaces"
20
+
21
  def softmax(x, axis):
22
  x_max = np.amax(x, axis=axis, keepdims=True)
23
  exp_x_shifted = np.exp(x - x_max)
 
60
  input_tensor = prompt
61
  input_tensor = input_tensor[None, :, :]
62
  cur_len = input_tensor.shape[1]
63
+ bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
64
  with bar:
65
  while cur_len < max_len:
66
  end = False
 
206
  parser = argparse.ArgumentParser()
207
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
208
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
209
+ parser.add_argument("--max-gen", type=int, default=512, help="max")
210
  opt = parser.parse_args()
211
  soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
212
  model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")