# NOTE: captured from a Hugging Face Spaces page; the Space status shown at capture time was "Runtime error".
import hmac
import json
import os
import shutil
import sqlite3
import subprocess
import uuid
from pathlib import Path

from datasets import load_dataset
from flask import Flask, request, jsonify
from flask_apscheduler import APScheduler
from flask_cors import CORS
from huggingface_hub import Repository
from PIL import Image
# Flask application; static assets are exposed under the /static URL prefix.
app = Flask(__name__, static_url_path='/static')
CORS(app)

# Access token read from the `dataset_token` environment variable
# (None when the variable is not set).
TOKEN = os.environ.get('dataset_token')
# Local working copy of the SQLite database.
DB_FILE = Path("./prompts.db")

# Clone the guesses dataset repository into ./data and pull the latest state.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/wordalle_guesses",
    use_auth_token=TOKEN
)
repo.git_pull()

# Seed the local working database from the copy stored in the repository.
# A repository that does not contain a database yet must not abort start-up
# with FileNotFoundError: sqlite3.connect() creates the file on demand and the
# start-up code creates the `prompts` table when it is missing.
if Path("./data/prompts.db").is_file():
    shutil.copyfile("./data/prompts.db", DB_FILE)

# Prompt/image dataset used to build the static assets.
dataset = load_dataset(
    "huggingface-projects/wordalle_prompts",
    use_auth_token=TOKEN)

# Make sure the image output directory exists.
Path("static/images").mkdir(parents=True, exist_ok=True)
# Make sure the `prompts` table exists in the local database.
db = sqlite3.connect(DB_FILE)
try:
    # Probe the table; executing the statement raises OperationalError when
    # the table is missing. The rows themselves are not needed here.
    db.execute("SELECT * FROM prompts")
except sqlite3.OperationalError:
    db.execute('CREATE TABLE prompts (guess TEXT, correct TEXT)')
    db.commit()
finally:
    # Close on every path; previously the connection stayed open whenever
    # the table had to be created.
    db.close()
# Extract images and prompts from the dataset and save them to disk.
# `data` maps each prompt string to the list of image file paths written
# for it; the mapping is then dumped to static/data.json.
label_feature = dataset['train'].features['label']  # loop-invariant lookup
data = {}
for row in dataset['train']:
    prompt = label_feature.int2str(row['label'])
    image = row['image']  # presumably a PIL image (it is resized below) - confirm
    # Random file name. Named `image_id` rather than `hash` so the builtin
    # `hash` is not shadowed at module scope.
    image_id = uuid.uuid4().hex
    image_file = Path(f'static/images/{image_id}.jpg')
    # Downscale to 136x136 before saving.
    image_compress = image.resize((136, 136), Image.Resampling.LANCZOS)
    image_compress.save(image_file, optimize=True, quality=95)
    data.setdefault(prompt, []).append(str(image_file))

with open('static/data.json', 'w') as f:
    json.dump(data, f)
def update_repository():
    """Snapshot the local database into the dataset checkout and push it.

    Pulls the repository, copies ``DB_FILE`` to ``./data/prompts.db``, dumps
    every row of the ``prompts`` table to ``./data/data.json`` as a compact
    JSON list of objects, and starts a non-blocking push to the hub.

    Returns:
        None.
    """
    repo.git_pull()
    # Copy the local working database into the repository checkout.
    shutil.copyfile(DB_FILE, "./data/prompts.db")

    # `with sqlite3.connect(...)` only manages a transaction and leaves the
    # connection open, so close it explicitly to avoid leaking one
    # connection per run.
    db = sqlite3.connect("./data/prompts.db")
    try:
        db.row_factory = sqlite3.Row  # rows convertible to dicts by column name
        result = db.execute("SELECT * FROM prompts").fetchall()
        data = [dict(row) for row in result]
    finally:
        db.close()

    with open('./data/data.json', 'w') as f:
        json.dump(data, f, separators=(',', ':'))

    print("Updating repository")
    repo.push_to_hub(blocking=False)
def index():
    """Return the static ``index.html`` page."""
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered with `app`.
    landing_page = app.send_static_file('index.html')
    return landing_page
def push():
    """Force-push the current local database to the dataset repository.

    The caller must send the shared secret in the ``token`` request header.

    Returns:
        ``("Success", 200)`` once the git commands have been run (their exit
        statuses are not checked, as before), or ``("Error", 401)`` when the
        token is missing or does not match ``TOKEN``.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered with `app`.
    supplied = request.headers.get('token')
    # Refuse when no secret is configured or none was sent, so that an unset
    # TOKEN can never compare equal to a missing header.
    if TOKEN is None or supplied is None:
        return "Error", 401
    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(supplied.encode('utf-8'), TOKEN.encode('utf-8')):
        return "Error", 401

    print("Force Push repository")
    shutil.copyfile(DB_FILE, "./data/prompts.db")
    # Run git inside the checkout through `cwd` instead of os.chdir(): the
    # working directory is process-wide, so changing it would break relative
    # paths used by other threads and would stay changed if a step raised.
    git_steps = (
        ["git", "add", "."],
        ["git", "commit", "-m", "force push"],
        ["git", "push", "--force"],
    )
    for command in git_steps:
        # Best effort, like the original os.system calls: `git commit` exits
        # non-zero when there is nothing to commit.
        subprocess.run(command, cwd="./data", check=False)
    return "Success", 200
def getdata():
    """Return the static ``data.json`` file."""
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered with `app`.
    payload_file = app.send_static_file('data.json')
    return payload_file
def create():
    """Store one submitted guess in the ``prompts`` table.

    Expects a POST request whose JSON body is an object with the keys
    ``guess`` and ``correct``.

    Returns:
        ``('OK', 200)`` when the row was inserted;
        ``('Missing guess or correct', 400)`` when the body is not a JSON
        object, lacks one of the keys, or holds a value SQLite cannot store;
        ``('Database error', 500)`` when the insert itself fails;
        ``('Method not allowed', 405)`` for any method other than POST.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered with `app`.
    if request.method != 'POST':
        # Previously this path fell through and returned None, which Flask
        # rejects as an invalid response.
        return 'Method not allowed', 405

    # silent=True yields None for a missing or malformed JSON body instead of
    # raising, so every bad payload gets the same 400 answer as before.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return 'Missing guess or correct', 400
    if 'guess' not in payload or 'correct' not in payload:
        return 'Missing guess or correct', 400

    guess = payload['guess']
    correct = payload['correct']
    # JSON arrays/objects cannot be bound as SQL parameters; reject them as a
    # client error rather than letting the insert fail.
    bindable = (str, int, float, type(None))
    if not isinstance(guess, bindable) or not isinstance(correct, bindable):
        return 'Missing guess or correct', 400

    try:
        db = sqlite3.connect(DB_FILE)
        try:
            with db:  # commits on success, rolls back on error
                db.execute(
                    'INSERT INTO prompts (guess, correct) VALUES (?, ?)',
                    (guess, correct))
        finally:
            # The connection context manager does not close the connection.
            db.close()
    except sqlite3.Error as err:
        # A storage failure is a server-side problem, not a bad request.
        print(f"Failed to store guess: {err}")
        return 'Database error', 500
    return 'OK', 200
if __name__ == '__main__':
    # Script entry point: FLASK_ENV selects the mode (default 'production').
    mode = os.environ.get('FLASK_ENV', 'production')
    print(mode)
    dev = mode == 'development'
    if not dev:
        # Outside development, sync the database to the dataset repository
        # every 300 seconds in the background.
        scheduler = APScheduler()
        scheduler.add_job(id='Update Dataset Repository',
                          func=update_repository, trigger='interval', seconds=300)
        scheduler.start()
    # Debug mode enables the interactive Werkzeug debugger, which must not be
    # reachable on a server bound to all interfaces; enable it (and the
    # reloader) only in development.
    app.run(host='0.0.0.0', port=int(
        os.environ.get('PORT', 7860)), debug=dev, use_reloader=dev)