# FLUX-1-dev-lora / app.py
# Last commit by soiz: "Update app.py" (1f6f05e, verified)
import hf_hub_download
from flask import Flask, request, jsonify, send_file, render_template_string
import requests
import io
import os
import random
from PIL import Image
from deep_translator import GoogleTranslator
import torch
from transformers import pipeline, hf_hub_download
from black_forest_labs import FluxPipeline # FluxPipelineのクラスはインポート元のモジュールに応じて変更
# Flask application that serves "/" and the "/generate" endpoint below.
app = Flask(__name__)
# Hosted Inference API endpoint for FLUX.1-dev and its auth header.
# NOTE(review): API_URL, headers and timeout are not referenced by any code in
# this file (generation runs on the local `pipe`); they look like leftovers from
# an API-based version — confirm before relying on them.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
# Token read from the HF_READ_TOKEN environment variable; when it is unset this
# is None and the header below becomes "Bearer None".
API_TOKEN = os.getenv("HF_READ_TOKEN")
headers = {"Authorization": f"Bearer {API_TOKEN}"}
# NOTE(review): the original comment said "set timeout to 300 seconds", but the
# value is 50000; the intended value/unit is unclear — confirm.
timeout = 50000
# Single FluxPipeline instance shared by all requests: FLUX.1-dev weights loaded
# in bfloat16 and moved to the "cuda" device at import time, so importing this
# module presumably requires a CUDA-capable GPU.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to(device="cuda", dtype=torch.bfloat16)
# Function to load LoRA weights dynamically
def load_lora_weights(pipe, lora_model_name, lora_scale=0.125,
                      weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors"):
    """Download a LoRA from the Hugging Face Hub and fuse it into *pipe*.

    ``fuse_lora`` merges the LoRA deltas into the base weights in place, so a
    given LoRA must be merged at most once. The (repo, file, scale) currently
    fused is recorded on the pipeline as ``_fused_lora_key``:

    * if the requested LoRA is already fused, nothing is done;
    * if a different one is fused, it is unfused and unloaded first.

    Args:
        pipe: Pipeline providing ``load_lora_weights``, ``fuse_lora``,
            ``unfuse_lora`` and ``unload_lora_weights``.
        lora_model_name: Hugging Face Hub repo id containing the LoRA file.
        lora_scale: Scale passed to ``fuse_lora``.
        weight_name: File inside the repo to download; defaults to the
            Hyper-FLUX 8-step LoRA that this function originally hard-coded.

    Returns:
        None on success, otherwise an ``"Error loading LoRA weights: ..."``
        message. Failures are reported through the return value, not raised.
    """
    requested = (lora_model_name, weight_name, lora_scale)
    if getattr(pipe, "_fused_lora_key", None) == requested:
        # Already merged into the weights; fusing again would apply it twice.
        return None
    try:
        # Download before touching the pipeline so a bad repo/file leaves the
        # currently fused LoRA untouched.
        lora_weights_path = hf_hub_download(lora_model_name, filename=weight_name)
        if getattr(pipe, "_fused_lora_key", None) is not None:
            # Undo the previous merge and drop its adapter before applying a new one.
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
            pipe._fused_lora_key = None
        pipe.load_lora_weights(lora_weights_path)
        pipe.fuse_lora(lora_scale=lora_scale)
        pipe._fused_lora_key = requested
        print(f"Loaded LoRA weights from {lora_model_name} with scale {lora_scale}")
    except Exception as e:
        print(f"Error loading LoRA weights: {e}")
        return f"Error loading LoRA weights: {e}"
    return None
def _translate_prompt(prompt):
    """Translate *prompt* to English, returning it unchanged if translation fails."""
    try:
        # 'auto' lets any source language (not only Russian) be translated;
        # English input passes through.
        translated = GoogleTranslator(source='auto', target='en').translate(prompt)
    except Exception as e:  # network/quota/length errors must not abort generation
        print(f"Translation failed, using the original prompt: {e}")
        return prompt
    # Guard against an empty result so the pipeline always gets some text.
    return translated or prompt


def _make_generator(seed):
    """Return a CPU ``torch.Generator`` seeded with *seed*, or None when *seed* is negative/None (random)."""
    if seed is None or seed < 0:
        return None
    return torch.Generator(device="cpu").manual_seed(seed)


# Run the local pipeline (optionally with a fused LoRA) and return the generated image
def query(prompt, negative_prompt="", steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024, lora=None, lora_scale=0.125):
    """Translate *prompt*, optionally fuse a LoRA, and generate one image with ``pipe``.

    NOTE(review): the keyword names passed to ``pipe`` follow the diffusers
    ``FluxPipeline.__call__`` API (which the ``from_pretrained`` /
    ``load_lora_weights`` / ``.images`` usage in this file matches) — confirm
    against the pipeline class actually imported.

    Args:
        prompt: Text prompt; translated to English before use.
        negative_prompt: Forwarded only when non-empty (not every FluxPipeline
            version accepts it).
        steps: Number of denoising steps (``num_inference_steps``).
        cfg_scale: Guidance scale (``guidance_scale``).
        sampler: Accepted for API compatibility; not used.
        seed: Non-negative values give reproducible output; negative means random.
        strength: Accepted for API compatibility; not forwarded, since the
            text-to-image FluxPipeline has no ``strength`` argument.
        width: Output width in pixels.
        height: Output height in pixels.
        lora: Optional Hub repo id of a LoRA to fuse before generating.
        lora_scale: Scale used when fusing the LoRA.

    Returns:
        ``(image, None)`` on success, ``(None, error_message)`` on failure.
    """
    if not prompt:
        return None, "Prompt is required"
    key = random.randint(0, 999)  # tag that correlates this request's log lines
    prompt = _translate_prompt(prompt)
    print(f'Generation {key} translation: {prompt}')
    # If a LoRA is requested, load and fuse it before generating.
    if lora:
        error = load_lora_weights(pipe, lora, lora_scale)
        if error:
            return None, error
    # Add some extra flair to the prompt
    prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
    print(f'Generation {key}: {prompt}')
    call_kwargs = {
        "num_inference_steps": steps,
        "guidance_scale": cfg_scale,
        "width": width,
        "height": height,
        "generator": _make_generator(seed),
    }
    if negative_prompt:
        call_kwargs["negative_prompt"] = negative_prompt
    try:
        image = pipe(prompt, **call_kwargs).images[0]
    except Exception as e:
        return None, f"Error during image generation: {e}"
    return image, None
# Attach a Content-Security-Policy header to every response
@app.after_request
def add_security_headers(response):
    """Set the Content-Security-Policy header on *response* and return it.

    The policy allows:

    * same-origin content by default;
    * fetch/XHR to same-origin and any http(s) URL;
    * images from same-origin and ``data:`` URIs;
    * same-origin plus inline styles and scripts.
    """
    # CSP source lists take scheme sources such as "https:", not regular
    # expressions; "https: http:" expresses the old regex's "any http(s) URL"
    # intent. Every directive must end with ";" or the next directive is
    # parsed as part of it (previously img-src was swallowed into connect-src).
    response.headers['Content-Security-Policy'] = (
        "default-src 'self'; "
        "connect-src 'self' https: http:; "
        "img-src 'self' data:; "
        "style-src 'self' 'unsafe-inline'; "
        "script-src 'self' 'unsafe-inline'; "
    )
    return response
# Inline Jinja template rendered for the landing page ("/").
# NOTE(review): in this copy the template holds only a newline, so "/" serves
# an effectively blank page; the page markup appears to be missing — confirm.
index_html = """
"""


@app.route('/')
def index():
    """Serve the landing page by rendering the inline ``index_html`` template."""
    page = render_template_string(index_html)
    return page
def _query_arg(args, name, default, cast):
    """Return query parameter *name* from *args* converted with *cast*.

    Uses *default* (also passed through *cast*) when the parameter is absent.
    Raises ValueError with a client-facing message naming the parameter when
    the value cannot be converted.
    """
    raw = args.get(name, default)
    try:
        return cast(raw)
    except (TypeError, ValueError):
        raise ValueError(f"Invalid value for '{name}': {raw!r}") from None


@app.route('/generate', methods=['GET'])
def generate_image():
    """Generate an image from query-string parameters and return it as a PNG.

    Query parameters: ``prompt``, ``negative_prompt``, ``steps`` (int),
    ``cfgs`` (float CFG scale), ``sampler``, ``strength`` (float), ``seed``
    (int), ``width`` (int), ``height`` (int), ``lora`` (Hub repo id) and
    ``lora_scale`` (float). All are optional at parse time; ``query`` rejects
    an empty prompt.

    Returns:
        The PNG image on success, otherwise ``{"error": ...}`` with status 400
        (malformed numeric parameter or an error reported by ``query``).
    """
    args = request.args
    prompt = args.get("prompt", "")
    negative_prompt = args.get("negative_prompt", "")
    sampler = args.get("sampler", "DPM++ 2M Karras")
    lora = args.get("lora", None)  # optional LoRA repo id
    try:
        steps = _query_arg(args, "steps", 35, int)
        cfg_scale = _query_arg(args, "cfgs", 7, float)
        strength = _query_arg(args, "strength", 0.7, float)
        seed = _query_arg(args, "seed", -1, int)
        width = _query_arg(args, "width", 1024, int)
        height = _query_arg(args, "height", 1024, int)
        lora_scale = _query_arg(args, "lora_scale", 0.125, float)
    except ValueError as exc:
        # A malformed number is a client error: report it in the same JSON/400
        # shape as other errors instead of letting ValueError become a 500.
        return jsonify({"error": str(exc)}), 400
    image, error = query(prompt, negative_prompt, steps, cfg_scale, sampler, seed, strength, width, height, lora, lora_scale)
    if error:
        return jsonify({"error": error}), 400
    img_bytes = io.BytesIO()
    image.save(img_bytes, format='PNG')
    img_bytes.seek(0)  # rewind so send_file streams from the beginning
    return send_file(img_bytes, mimetype='image/png')
if __name__ == "__main__":
    # Start Flask's built-in server on every network interface, port 7860.
    app.run(port=7860, host='0.0.0.0')