skytnt commited on
Commit
b3f8835
1 Parent(s): a158362
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -118,6 +118,9 @@ def send_msgs(msgs):
118
  return json.dumps(msgs)
119
 
120
 
 
 
 
121
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
122
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
123
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
@@ -129,7 +132,7 @@ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_sele
129
  start_events = len(mid_seq[0])
130
  else:
131
  start_events = 1
132
- t = 8.5e-5 * (gen_events+start_events) ** 2 - 8.5e-5 * start_events ** 2 + 23
133
  if "large" in model_name:
134
  t *= 2
135
  return t
 
118
  return json.dumps(msgs)
119
 
120
 
121
+ def calc_time(x):
122
+ return 5.849e-5*x**2 + 0.04781*x + 0.1168
123
+
124
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
125
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
126
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
 
132
  start_events = len(mid_seq[0])
133
  else:
134
  start_events = 1
135
+ t = calc_time(start_events + gen_events) - calc_time(start_events) + 5
136
  if "large" in model_name:
137
  t *= 2
138
  return t