Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,380 Bytes
c9cc441 d7a562e c9cc441 721391f c9cc441 d7a562e dd5dfc8 d7a562e dd5dfc8 721391f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
from flask import Flask, request, render_template, send_file, jsonify, send_from_directory
from flask_socketio import SocketIO, emit
from flask_cors import CORS
import io
import os
import argparse
from PIL import Image
import torch
import gc
from peft import PeftModel
import queue
import threading
import uuid
import concurrent.futures
from scripts.process_utils import *
# Flask application; CORS is enabled for all routes so front-ends served
# from other origins can call the HTTP endpoints.
app = Flask(__name__)
CORS(app)

# Socket.IO server (any origin allowed) used to push queue and task events.
socketio = SocketIO(app, cors_allowed_origins="*")

# Tasks waiting for the worker thread to hand them to the executor.
task_queue = queue.Queue()

# task_id -> Task for submitted tasks that have not yet finished or been cancelled.
active_tasks = dict()

# task_id -> Future returned by the executor for tasks already submitted to it.
task_futures = dict()

# Executor with a single worker thread: tasks are processed one at a time.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
class Task:
    """One image-processing job submitted through the task queue.

    Stores the raw uploaded bytes, the processing parameters, and a
    ``cancel_flag`` that the processing code checks between stages.
    """

    def __init__(self, task_id: str, mode: str, weight1: float, weight2: float, file_data: bytes):
        # Identity and processing parameters.
        self.task_id = task_id
        self.mode = mode
        self.weight1, self.weight2 = weight1, weight2
        # Raw bytes of the uploaded file; decoded later by the worker.
        self.file_data = file_data
        # Flipped to True when cancellation is requested.
        self.cancel_flag = False
def update_queue_status(message=None):
    """Emit a ``queue_update`` event carrying the number of active tasks.

    Args:
        message: Optional human-readable note sent alongside the count.
    """
    payload = {
        'active_tasks': len(active_tasks),
        'message': message,
    }
    socketio.emit('queue_update', payload)
def process_task(task):
    """Run one queued task on the executor thread and report its result.

    Decodes ``task.file_data`` into an RGB image, calls
    ``process_image_as_base64`` with the task's mode and weights, and emits
    ``task_complete`` (or ``task_error``) over Socket.IO. If
    ``task.cancel_flag`` is set, processing stops at the next checkpoint and
    nothing is emitted for the task. The task is always removed from
    ``active_tasks`` / ``task_futures`` afterwards and a queue update is sent.

    Args:
        task: The ``Task`` to process.
    """
    try:
        # Turn the stored upload bytes into an RGB PIL image.
        image = ensure_rgb(Image.open(io.BytesIO(task.file_data)))

        # Cooperative cancellation checkpoint before the expensive step.
        if task.cancel_flag:
            return

        sotai_image, sketch_image = process_image_as_base64(
            image, task.mode, task.weight1, task.weight2
        )

        # Cancellation may have been requested while processing ran; if so,
        # the result is dropped.
        if task.cancel_flag:
            return

        socketio.emit('task_complete', {
            'task_id': task.task_id,
            'sotai_image': sotai_image,
            'sketch_image': sketch_image,
        })
    except Exception as e:
        # Errors from tasks that were cancelled are intentionally not reported.
        if not task.cancel_flag:
            socketio.emit('task_error', {'task_id': task.task_id, 'error': str(e)})
    finally:
        # pop(..., None) instead of check-then-del: cancel_task runs on a
        # request thread and may remove these entries concurrently, in which
        # case a plain `del` would raise KeyError and skip the status update.
        active_tasks.pop(task.task_id, None)
        task_futures.pop(task.task_id, None)
        update_queue_status('Task completed or cancelled')
def worker():
    """Move queued tasks onto the executor, forever (meant for a daemon thread).

    A dequeued task is skipped if its id is no longer in ``active_tasks``
    (e.g. it was cancelled while still waiting in the queue).
    """
    while True:
        # Blocking get kept outside the try: task_done() must only be called
        # for an item that was actually retrieved from the queue.
        task = task_queue.get()
        try:
            if task.task_id in active_tasks:
                future = executor.submit(process_task, task)
                task_futures[task.task_id] = future
                # process_task may finish (and run its own cleanup) before the
                # future is stored above, which would leave a stale entry in
                # task_futures. The callback removes the entry once the future
                # is done; it runs immediately if the future is already done.
                # task_id is bound as a default argument so the lambda does not
                # pick up a later iteration's `task`.
                future.add_done_callback(
                    lambda _fut, tid=task.task_id: task_futures.pop(tid, None)
                )
                # NOTE(review): the executor has a single worker, so the task
                # may still be waiting inside the executor when this is sent.
                update_queue_status(f'Task started: {task.task_id}')
        except Exception as e:
            print(f"Worker error: {str(e)}")
        finally:
            # Always mark the retrieved item as handled.
            task_queue.task_done()
# Start the background thread that feeds queued tasks to the executor.
# It is a daemon thread, so it does not keep the process alive at exit.
_queue_worker_thread = threading.Thread(target=worker, daemon=True)
_queue_worker_thread.start()
@app.route('/submit_task', methods=['POST'])
def submit_task():
    """Accept an uploaded image and enqueue it for asynchronous processing.

    Form fields: ``file`` (required), ``mode`` (default ``'refine'``),
    ``weight1`` (default 0.4), ``weight2`` (default 0.3).

    Returns:
        JSON ``{"task_id": ..., "queue_size": ...}``.
    """
    task_id = str(uuid.uuid4())
    file = request.files['file']
    mode = request.form.get('mode', 'refine')
    weight1 = float(request.form.get('weight1', 0.4))
    weight2 = float(request.form.get('weight2', 0.3))

    # Read the upload into memory now; the task is processed later on
    # another thread, and the upload is only guaranteed readable during
    # this request.
    task = Task(task_id, mode, weight1, weight2, file.read())

    # Register the task BEFORE enqueueing it. The worker thread skips any
    # dequeued task whose id is not in active_tasks, so enqueueing first
    # lets the worker race ahead and silently drop the new task.
    active_tasks[task_id] = task
    task_queue.put(task)

    update_queue_status(f'Task submitted: {task_id}')
    queue_size = task_queue.qsize()
    return jsonify({'task_id': task_id, 'queue_size': queue_size})
@app.route('/cancel_task/<task_id>', methods=['POST'])
def cancel_task(task_id):
    """Request cancellation of a pending or running task.

    Sets the task's ``cancel_flag``, cancels its executor future (which only
    has an effect if the task has not started yet), and removes the task
    from ``active_tasks`` / ``task_futures``.

    Returns:
        JSON message with status 200 on success, or 404 if the task is
        unknown or already finished.
    """
    task = active_tasks.get(task_id)
    if task is None:
        return jsonify({'message': 'Task not found or already completed'}), 404

    # Set the flag first so a task that is already running stops at its
    # next cancellation checkpoint.
    task.cancel_flag = True

    # pop(..., None) instead of check-then-del: process_task's cleanup runs
    # on the executor thread and may remove these entries concurrently, in
    # which case `del` would raise KeyError and turn this into a 500.
    future = task_futures.pop(task_id, None)
    if future is not None:
        future.cancel()
    active_tasks.pop(task_id, None)

    update_queue_status('Task cancelled')
    return jsonify({'message': 'Task cancellation requested'})
def get_active_task_order(task_id):
    """Return the 0-based position of *task_id* among active tasks, or None.

    The position follows the insertion order of ``active_tasks``, i.e. the
    order in which still-active tasks were submitted.
    """
    # Take a single snapshot of the keys. Checking membership and then
    # indexing a second copy could raise ValueError if another thread
    # removed the task between the two steps.
    try:
        return list(active_tasks).index(task_id)
    except ValueError:
        return None
# HTTP endpoint (not a Socket.IO event) reporting a task's position among active tasks.
@app.route('/get_task_order/<task_id>', methods=['GET'])
def handle_get_task_order(task_id):
    """Return JSON ``{"task_order": <position or null>}`` for *task_id*."""
    return jsonify(dict(task_order=get_active_task_order(task_id)))
@socketio.on('connect')
def handle_connect():
    """On client connect, emit a ``queue_update`` with the active-task count."""
    snapshot = {
        'active_tasks': len(active_tasks),
        'active_task_order': None,
    }
    emit('queue_update', snapshot)
# Flask routes that process an uploaded image synchronously within the request.
@app.route('/', methods=['GET', 'POST'])
def process_refined():
    """Process an uploaded image in ``"refine"`` mode and return both outputs.

    POST form fields: ``file`` (required), ``weight1`` (default 0.4),
    ``weight2`` (default 0.3).

    Returns:
        For POST, JSON with ``sotai_image`` and ``sketch_image`` as returned
        by ``process_image_as_base64`` (presumably base64 strings). For any
        other method, a JSON error with status 405.
    """
    if request.method != 'POST':
        # Without this, a GET would return None, which Flask rejects as an
        # invalid response and turns into a 500 error.
        return jsonify({'error': 'Use POST with an image in the "file" form field'}), 405

    file = request.files['file']
    weight1 = float(request.form.get('weight1', 0.4))
    weight2 = float(request.form.get('weight2', 0.3))
    image = ensure_rgb(Image.open(file.stream))
    sotai_image, sketch_image = process_image_as_base64(image, "refine", weight1, weight2)
    return jsonify({
        'sotai_image': sotai_image,
        'sketch_image': sketch_image,
    })
@app.route('/process_original', methods=['GET', 'POST'])
def process_original():
    """Process an uploaded image in ``"original"`` mode and return both outputs.

    POST form field: ``file`` (required).

    Returns:
        For POST, JSON with ``sotai_image`` and ``sketch_image`` as returned
        by ``process_image_as_base64`` (presumably base64 strings). For any
        other method, a JSON error with status 405.
    """
    if request.method != 'POST':
        # Without this, a GET would return None, which Flask rejects as an
        # invalid response and turns into a 500 error.
        return jsonify({'error': 'Use POST with an image in the "file" form field'}), 405

    file = request.files['file']
    image = ensure_rgb(Image.open(file.stream))
    sotai_image, sketch_image = process_image_as_base64(image, "original")
    return jsonify({
        'sotai_image': sotai_image,
        'sketch_image': sketch_image,
    })
@app.route('/process_sketch', methods=['GET', 'POST'])
def process_sketch():
    """Process an uploaded image in ``"sketch"`` mode and return both outputs.

    POST form field: ``file`` (required).

    Returns:
        For POST, JSON with ``sotai_image`` and ``sketch_image`` as returned
        by ``process_image_as_base64`` (presumably base64 strings). For any
        other method, a JSON error with status 405.
    """
    if request.method != 'POST':
        # Without this, a GET would return None, which Flask rejects as an
        # invalid response and turns into a 500 error.
        return jsonify({'error': 'Use POST with an image in the "file" form field'}), 405

    file = request.files['file']
    image = ensure_rgb(Image.open(file.stream))
    sotai_image, sketch_image = process_image_as_base64(image, "sketch")
    return jsonify({
        'sotai_image': sotai_image,
        'sketch_image': sketch_image,
    })
# Error handlers
@app.errorhandler(500)
def server_error(e):
    """Return HTTP 500 errors as JSON ``{"error": "<message>"}``."""
    body = jsonify(error=str(e))
    return body, 500
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Server options.')
    parser.add_argument('--use_local', action='store_true', help='Use local model')
    parser.add_argument('--use_gpu', action='store_true', help='Set to True to use GPU but if not available, it will use CPU')
    # Previously hard-coded; the defaults preserve the old behavior.
    parser.add_argument('--host', default='0.0.0.0', help='Interface to bind (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=80, help='Port to listen on (default: 80)')
    args = parser.parse_args()
    initialize(args.use_local, args.use_gpu)
    # debug=True normally also enables the Werkzeug reloader. The reloader
    # re-executes this script in a child process, which would call
    # initialize() again and start a second worker thread, so it is disabled.
    # NOTE(review): debug mode on a publicly bound interface exposes
    # debugging features; consider turning it off for deployments.
    socketio.run(app, debug=True, host=args.host, port=args.port, use_reloader=False)