Review DB management

#1
by Wauplin HF staff - opened
Files changed (1) hide show
  1. app.py +77 -85
app.py CHANGED
@@ -7,8 +7,48 @@ import sqlite3
7
  from datasets import load_dataset
8
  import threading
9
  import time
10
- from huggingface_hub import HfApi
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  MUST_BE_LOGGEDIN = "Please login with Hugging Face to participate in the TTS Arena."
13
  DESCR = """
14
  # TTS Arena
@@ -25,11 +65,11 @@ INSTR = """
25
  **When you're ready to begin, click the Start button below!** The model names will be revealed once you vote.
26
  """.strip()
27
  request = ''
28
- if os.getenv('HF_ID'):
29
  request = f"""
30
  ### Request Model
31
 
32
- Please fill out [this form](https://huggingface.co/spaces/{os.getenv('HF_ID')}/discussions/new?title=%5BModel+Request%5D+&description=%23%23%20Model%20Request%0A%0A%2A%2AModel%20website%2Fpaper%20%28if%20applicable%29%2A%2A%3A%0A%2A%2AModel%20available%20on%2A%2A%3A%20%28coqui%7CHF%20pipeline%7Ccustom%20code%29%0A%2A%2AWhy%20do%20you%20want%20this%20model%20added%3F%2A%2A%0A%2A%2AComments%3A%2A%2A) to request a model.
33
  """
34
  ABOUT = f"""
35
  ## About
@@ -57,28 +97,29 @@ A list of the models, based on how highly they are ranked!
57
  """.strip()
58
 
59
 
60
- dataset = load_dataset("ttseval/tts-arena-new", token=os.getenv('HF_TOKEN'))
61
- def reload_db():
62
- global dataset
63
- dataset = load_dataset("ttseval/tts-arena-new", token=os.getenv('HF_TOKEN'))
64
- return 'Reload Dataset'
 
 
65
  def del_db(txt):
66
  if not txt.lower() == 'delete db':
67
  raise gr.Error('You did not enter "delete db"')
68
- api = HfApi(
69
- token=os.getenv('HF_TOKEN')
70
- )
71
- os.remove('database.db')
72
- create_db()
73
- api.delete_file(
74
- path_in_repo='database.db',
75
- repo_id=os.getenv('DATASET_ID'),
76
- repo_type='dataset'
77
- )
78
  return 'Delete DB'
 
79
  theme = gr.themes.Base(
80
  font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
81
  )
 
82
  model_names = {
83
  'styletts2': 'StyleTTS 2',
84
  'tacotron': 'Tacotron',
@@ -126,14 +167,15 @@ model_licenses = {
126
  'speecht5': 'MIT',
127
  }
128
  # def get_random_split(existing_split=None):
129
- # choice = random.choice(list(dataset.keys()))
130
  # if existing_split and choice == existing_split:
131
  # return get_random_split(choice)
132
  # else:
133
  # return choice
134
  def get_db():
135
- return sqlite3.connect('database.db')
136
- def create_db():
 
137
  conn = get_db()
138
  cursor = conn.cursor()
139
  cursor.execute('''
@@ -152,7 +194,7 @@ def create_db():
152
  );
153
  ''')
154
 
155
- def get_data():
156
  conn = get_db()
157
  cursor = conn.cursor()
158
  cursor.execute('SELECT name, upvote, downvote FROM model WHERE (upvote + downvote) > 5')
@@ -193,10 +235,10 @@ def upvote_model(model, uname):
193
  if cursor.rowcount == 0:
194
  cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 1, 0)', (model,))
195
  cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, 1,))
196
- conn.commit()
 
197
  cursor.close()
198
 
199
-
200
  def downvote_model(model, uname):
201
  conn = get_db()
202
  cursor = conn.cursor()
@@ -204,8 +246,10 @@ def downvote_model(model, uname):
204
  if cursor.rowcount == 0:
205
  cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 0, 1)', (model,))
206
  cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, -1,))
207
- conn.commit()
 
208
  cursor.close()
 
209
  def a_is_better(model1, model2, profile: gr.OAuthProfile | None):
210
  if not profile:
211
  raise gr.Error(MUST_BE_LOGGEDIN)
@@ -236,8 +280,8 @@ def both_good(model1, model2, profile: gr.OAuthProfile | None):
236
  return reload(model1, model2)
237
  def reload(chosenmodel1=None, chosenmodel2=None):
238
  # Select random splits
239
- row = random.choice(list(dataset['train']))
240
- options = list(random.choice(list(dataset['train'])).keys())
241
  split1, split2 = random.sample(options, 2)
242
  choice1, choice2 = (row[split1], row[split2])
243
  if chosenmodel1 in model_names:
@@ -256,11 +300,11 @@ def reload(chosenmodel1=None, chosenmodel2=None):
256
 
257
  with gr.Blocks() as leaderboard:
258
  gr.Markdown(LDESC)
259
- # df = gr.Dataframe(interactive=False, value=get_data())
260
  df = gr.Dataframe(interactive=False, min_width=0, wrap=True, column_widths=[30, 200, 50, 75, 50])
261
  reloadbtn = gr.Button("Refresh")
262
- leaderboard.load(get_data, outputs=[df])
263
- reloadbtn.click(get_data, outputs=[df])
264
  gr.Markdown("DISCLAIMER: The licenses listed may not be accurate or up to date, you are responsible for checking the licenses before using the models. Also note that some models may have additional usage restrictions.")
265
 
266
  with gr.Blocks() as vote:
@@ -310,8 +354,8 @@ with gr.Blocks() as vote:
310
  with gr.Blocks() as about:
311
  gr.Markdown(ABOUT)
312
  with gr.Blocks() as admin:
313
- rdb = gr.Button("Reload Dataset")
314
- rdb.click(reload_db, outputs=rdb)
315
  with gr.Group():
316
  dbtext = gr.Textbox(label="Type \"delete db\" to confirm", placeholder="delete db")
317
  ddb = gr.Button("Delete DB")
@@ -319,57 +363,5 @@ with gr.Blocks() as admin:
319
  with gr.Blocks(theme=theme, css="footer {visibility: hidden}textbox{resize:none}", title="TTS Leaderboard") as demo:
320
  gr.Markdown(DESCR)
321
  gr.TabbedInterface([vote, leaderboard, about, admin], ['Vote', 'Leaderboard', 'About', 'Admin (ONLY IN BETA)'])
322
- def restart_space():
323
- api = HfApi(
324
- token=os.getenv('HF_TOKEN')
325
- )
326
- time.sleep(60 * 60) # Every hour
327
- print("Syncing DB before restarting space")
328
- api.upload_file(
329
- path_or_fileobj='database.db',
330
- path_in_repo='database.db',
331
- repo_id=os.getenv('DATASET_ID'),
332
- repo_type='dataset'
333
- )
334
- print("Restarting space")
335
- api.restart_space(repo_id=os.getenv('HF_ID'))
336
- def sync_db():
337
- api = HfApi(
338
- token=os.getenv('HF_TOKEN')
339
- )
340
- while True:
341
- time.sleep(60 * 10)
342
- print("Uploading DB")
343
- api.upload_file(
344
- path_or_fileobj='database.db',
345
- path_in_repo='database.db',
346
- repo_id=os.getenv('DATASET_ID'),
347
- repo_type='dataset'
348
- )
349
- if os.getenv('HF_ID'):
350
- restart_thread = threading.Thread(target=restart_space)
351
- restart_thread.daemon = True
352
- restart_thread.start()
353
- if os.getenv('DATASET_ID'):
354
- # Fetch DB
355
- api = HfApi(
356
- token=os.getenv('HF_TOKEN')
357
- )
358
- print("Downloading DB...")
359
- try:
360
- path = api.hf_hub_download(
361
- repo_id=os.getenv('DATASET_ID'),
362
- repo_type='dataset',
363
- filename='database.db',
364
- cache_dir='./'
365
- )
366
- shutil.copyfile(path, 'database.db')
367
- print("Downloaded DB")
368
- except:
369
- pass
370
- # Update DB
371
- db_thread = threading.Thread(target=sync_db)
372
- db_thread.daemon = True
373
- db_thread.start()
374
- create_db()
375
  demo.queue(api_open=False).launch(show_api=False)
 
7
  from datasets import load_dataset
8
  import threading
9
  import time
10
+ from pathlib import Path
11
+ from huggingface_hub import CommitScheduler, delete_file, hf_hub_download
12
 
13
+ SPACE_ID = os.getenv('HF_ID')
14
+
15
+ DB_DATASET_ID = os.getenv('DATASET_ID')
16
+ DB_NAME = "database.db"
17
+ DB_PATH = "database.db"
18
+
19
+ AUDIO_DATASET_ID = "ttseval/tts-arena-new"
20
+
21
+ ####################################
22
+ # Space initialization
23
+ ####################################
24
+
25
+ # Download existing DB
26
+ print("Downloading DB...")
27
+ try:
28
+ cache_path = hf_hub_download(repo_id=DB_DATASET_ID, repo_type='dataset', filename=DB_NAME)
29
+ shutil.copyfile(cache_path, DB_PATH)
30
+ print("Downloaded DB")
31
+ except Exception as e:
32
+ print("Error while downloading DB:", e)
33
+
34
+ # Create DB table (if doesn't exist)
35
+ create_db_if_missing()
36
+
37
+ # Sync local DB with remote repo every 5 minute (only if a change is detected)
38
+ scheduler = CommitScheduler(
39
+ repo_id=DB_DATASET_ID,
40
+ repo_type="dataset",
41
+ folder_path=Path(DB_PATH).parent,
42
+ every=5,
43
+ allow_patterns=DB_NAME,
44
+ )
45
+
46
+ # Load audio dataset
47
+ audio_dataset = load_dataset(AUDIO_DATASET_ID)
48
+
49
+ ####################################
50
+ # Gradio app
51
+ ####################################
52
  MUST_BE_LOGGEDIN = "Please login with Hugging Face to participate in the TTS Arena."
53
  DESCR = """
54
  # TTS Arena
 
65
  **When you're ready to begin, click the Start button below!** The model names will be revealed once you vote.
66
  """.strip()
67
  request = ''
68
+ if SPACE_ID:
69
  request = f"""
70
  ### Request Model
71
 
72
+ Please fill out [this form](https://huggingface.co/spaces/{SPACE_ID}/discussions/new?title=%5BModel+Request%5D+&description=%23%23%20Model%20Request%0A%0A%2A%2AModel%20website%2Fpaper%20%28if%20applicable%29%2A%2A%3A%0A%2A%2AModel%20available%20on%2A%2A%3A%20%28coqui%7CHF%20pipeline%7Ccustom%20code%29%0A%2A%2AWhy%20do%20you%20want%20this%20model%20added%3F%2A%2A%0A%2A%2AComments%3A%2A%2A) to request a model.
73
  """
74
  ABOUT = f"""
75
  ## About
 
97
  """.strip()
98
 
99
 
100
+
101
+
102
+ def reload_audio_dataset():
103
+ global audio_dataset
104
+ audio_dataset = load_dataset(AUDIO_DATASET_ID)
105
+ return 'Reload audio dataset'
106
+
107
  def del_db(txt):
108
  if not txt.lower() == 'delete db':
109
  raise gr.Error('You did not enter "delete db"')
110
+
111
+ # Delete local + remote
112
+ os.remove(DB_PATH)
113
+ delete_file(path_in_repo=DB_NAME, repo_id=DATASET_ID, repo_type='dataset')
114
+
115
+ # Recreate
116
+ create_db_if_missing()
 
 
 
117
  return 'Delete DB'
118
+
119
  theme = gr.themes.Base(
120
  font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
121
  )
122
+
123
  model_names = {
124
  'styletts2': 'StyleTTS 2',
125
  'tacotron': 'Tacotron',
 
167
  'speecht5': 'MIT',
168
  }
169
  # def get_random_split(existing_split=None):
170
+ # choice = random.choice(list(audio_dataset.keys()))
171
  # if existing_split and choice == existing_split:
172
  # return get_random_split(choice)
173
  # else:
174
  # return choice
175
  def get_db():
176
+ return sqlite3.connect(DB_PATH)
177
+
178
+ def create_db_if_missing():
179
  conn = get_db()
180
  cursor = conn.cursor()
181
  cursor.execute('''
 
194
  );
195
  ''')
196
 
197
+ def get_leaderboard():
198
  conn = get_db()
199
  cursor = conn.cursor()
200
  cursor.execute('SELECT name, upvote, downvote FROM model WHERE (upvote + downvote) > 5')
 
235
  if cursor.rowcount == 0:
236
  cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 1, 0)', (model,))
237
  cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, 1,))
238
+ with scheduler.lock:
239
+ conn.commit()
240
  cursor.close()
241
 
 
242
  def downvote_model(model, uname):
243
  conn = get_db()
244
  cursor = conn.cursor()
 
246
  if cursor.rowcount == 0:
247
  cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 0, 1)', (model,))
248
  cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, -1,))
249
+ with scheduler.lock:
250
+ conn.commit()
251
  cursor.close()
252
+
253
  def a_is_better(model1, model2, profile: gr.OAuthProfile | None):
254
  if not profile:
255
  raise gr.Error(MUST_BE_LOGGEDIN)
 
280
  return reload(model1, model2)
281
  def reload(chosenmodel1=None, chosenmodel2=None):
282
  # Select random splits
283
+ row = random.choice(list(audio_dataset['train']))
284
+ options = list(random.choice(list(audio_dataset['train'])).keys())
285
  split1, split2 = random.sample(options, 2)
286
  choice1, choice2 = (row[split1], row[split2])
287
  if chosenmodel1 in model_names:
 
300
 
301
  with gr.Blocks() as leaderboard:
302
  gr.Markdown(LDESC)
303
+ # df = gr.Dataframe(interactive=False, value=get_leaderboard())
304
  df = gr.Dataframe(interactive=False, min_width=0, wrap=True, column_widths=[30, 200, 50, 75, 50])
305
  reloadbtn = gr.Button("Refresh")
306
+ leaderboard.load(get_leaderboard, outputs=[df])
307
+ reloadbtn.click(get_leaderboard, outputs=[df])
308
  gr.Markdown("DISCLAIMER: The licenses listed may not be accurate or up to date, you are responsible for checking the licenses before using the models. Also note that some models may have additional usage restrictions.")
309
 
310
  with gr.Blocks() as vote:
 
354
  with gr.Blocks() as about:
355
  gr.Markdown(ABOUT)
356
  with gr.Blocks() as admin:
357
+ rdb = gr.Button("Reload Audio Dataset")
358
+ rdb.click(reload_audio_dataset, outputs=rdb)
359
  with gr.Group():
360
  dbtext = gr.Textbox(label="Type \"delete db\" to confirm", placeholder="delete db")
361
  ddb = gr.Button("Delete DB")
 
363
  with gr.Blocks(theme=theme, css="footer {visibility: hidden}textbox{resize:none}", title="TTS Leaderboard") as demo:
364
  gr.Markdown(DESCR)
365
  gr.TabbedInterface([vote, leaderboard, about, admin], ['Vote', 'Leaderboard', 'About', 'Admin (ONLY IN BETA)'])
366
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  demo.queue(api_open=False).launch(show_api=False)