File size: 4,350 Bytes
1f6f05e
46c61dc
435f67e
 
 
 
 
 
46c61dc
 
 
435f67e
 
 
286f519
 
 
 
61fd9f3
435f67e
46c61dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435f67e
46c61dc
435f67e
 
 
 
 
 
 
 
 
46c61dc
 
 
 
 
 
435f67e
 
 
 
46c61dc
 
 
 
 
 
435f67e
700fd72
 
 
 
 
74d0aae
700fd72
 
 
 
 
1fb1361
700fd72
 
435f67e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b19ef1d
46c61dc
435f67e
46c61dc
700fd72
435f67e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import hf_hub_download
from flask import Flask, request, jsonify, send_file, render_template_string
import requests
import io
import os
import random
from PIL import Image
from deep_translator import GoogleTranslator
import torch
from transformers import pipeline, hf_hub_download
from black_forest_labs import FluxPipeline  # FluxPipelineのクラスはインポート元のモジュールに応じて変更

app = Flask(__name__)

# Hugging Face Inference API endpoint for FLUX.1-dev plus the bearer-token header for it.
# NOTE(review): API_URL, headers and timeout are not referenced anywhere in this file's
# visible code; image generation runs locally through `pipe`.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
API_TOKEN = os.environ.get("HF_READ_TOKEN")  # None when HF_READ_TOKEN is unset
headers = dict(Authorization=f"Bearer {API_TOKEN}")
# NOTE(review): the value is 50000; if this is meant as seconds that is roughly 14 hours —
# confirm the intended unit and value.
timeout = 50_000

# Shared FluxPipeline instance used by every request (see `query`).
# It is created at import time: "black-forest-labs/FLUX.1-dev" is loaded in bfloat16
# and then moved to the "cuda" device. Importing this module therefore needs those
# weights to be available (presumably downloaded or cached via the Hub — confirm)
# and needs a CUDA device.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to(device="cuda", dtype=torch.bfloat16)

# Default LoRA weight file fetched from the requested Hub repository.
_DEFAULT_LORA_FILENAME = "Hyper-FLUX.1-dev-8steps-lora.safetensors"


# Function to load LoRA weights dynamically
def load_lora_weights(pipe, lora_model_name, lora_scale=0.125, filename=_DEFAULT_LORA_FILENAME):
    """Download a LoRA from the Hugging Face Hub and fuse it into *pipe*.

    Fusing bakes the LoRA delta into the base weights, so fusing on every
    request would stack the same adapter repeatedly. The currently fused
    LoRA is therefore remembered on ``pipe._fused_lora``:

    * the same (repo, filename, scale) again is a no-op;
    * a different one unfuses and unloads the previous adapter first.

    Args:
        pipe: Pipeline exposing ``load_lora_weights``, ``fuse_lora``,
            ``unfuse_lora`` and ``unload_lora_weights``. This is the
            diffusers-style LoRA API — confirm for the installed
            FluxPipeline.
        lora_model_name: Hub repository id that holds the LoRA weights.
        lora_scale: Scale passed to ``fuse_lora``.
        filename: Weight file inside the repository. Defaults to the
            Hyper-FLUX 8-step LoRA file.

    Returns:
        ``None`` on success, otherwise an error message string.
    """
    wanted = (lora_model_name, filename, lora_scale)
    current = getattr(pipe, "_fused_lora", None)
    if current == wanted:
        # Already fused with identical settings; fusing again would double the delta.
        return None
    try:
        # Download first, so a bad repo id or filename leaves the pipeline untouched.
        lora_weights_path = hf_hub_download(lora_model_name, filename=filename)
        if current is not None:
            # Restore the base weights and drop the old adapter before applying a new one.
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
            pipe._fused_lora = None
        pipe.load_lora_weights(lora_weights_path)
        pipe.fuse_lora(lora_scale=lora_scale)
    except Exception as e:
        print(f"Error loading LoRA weights: {e}")
        return f"Error loading LoRA weights: {e}"
    pipe._fused_lora = wanted
    print(f"Loaded LoRA weights from {lora_model_name} with scale {lora_scale}")
    return None

# Function to generate an image with the shared local pipeline
def query(prompt, negative_prompt="", steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024, lora=None, lora_scale=0.125):
    """Generate one image for *prompt* with the shared ``pipe``.

    The steps, in order:

    1. Translate the prompt to English (source language auto-detected).
    2. Fuse the requested LoRA, if one is given.
    3. Append a fixed quality suffix.
    4. Run the pipeline once.

    Args:
        prompt: User prompt; required.
        negative_prompt: Forwarded to the pipeline only when non-empty.
        steps: Number of denoising steps (``num_inference_steps``).
        cfg_scale: Guidance scale (``guidance_scale``).
        sampler: Accepted for interface compatibility; not used.
        seed: RNG seed. A negative value (default ``-1``) or ``None``
            picks a random seed.
        strength: Accepted for interface compatibility; not used. It only
            applies to image-to-image pipelines, and this call passes no
            input image.
        width: Output width in pixels.
        height: Output height in pixels.
        lora: Hub repo id of a LoRA to fuse before generating, or falsy
            for none.
        lora_scale: Scale used when fusing *lora*.

    Returns:
        ``(image, None)`` on success, or ``(None, error_message)`` on
        failure.
    """
    if not prompt:
        return None, "Prompt is required"

    key = random.randint(0, 999)  # short id that ties together the log lines of one request

    # Translate to English only when needed: auto-detection leaves English prompts alone
    # and handles languages other than Russian. Translation is best-effort; on failure
    # (e.g. network error) the original prompt is used.
    try:
        translated = GoogleTranslator(source='auto', target='en').translate(prompt)
        prompt = translated or prompt
    except Exception as e:
        print(f'Generation {key} translation failed, using original prompt: {e}')
    print(f'Generation {key} translation: {prompt}')

    # Fuse the requested LoRA (if any) into the shared pipeline before generating.
    if lora:
        error = load_lora_weights(pipe, lora, lora_scale)
        if error:
            return None, error

    # Add some extra flair to the prompt
    prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
    print(f'Generation {key}: {prompt}')

    if seed is None or seed < 0:
        seed = random.randint(0, 2**32 - 1)
    print(f'Generation {key} seed: {seed}')

    try:
        # NOTE(review): keyword names follow the diffusers FluxPipeline __call__ signature
        # (num_inference_steps / guidance_scale / generator); confirm against the
        # FluxPipeline actually imported by this file.
        generation_kwargs = {
            "height": height,
            "width": width,
            "num_inference_steps": steps,
            "guidance_scale": cfg_scale,
            "generator": torch.Generator("cpu").manual_seed(seed),
        }
        if negative_prompt:
            generation_kwargs["negative_prompt"] = negative_prompt
        image = pipe(prompt, **generation_kwargs).images[0]
        return image, None
    except Exception as e:
        return None, f"Error during image generation: {e}"

# Attach a Content-Security-Policy header to every response.
@app.after_request
def add_security_headers(response):
    """Set the ``Content-Security-Policy`` header on *response* and return it.

    CSP source lists accept scheme/host expressions rather than regular
    expressions, and every directive must be terminated by ``;``.
    Otherwise the next directive is parsed as extra sources of the
    previous one.
    """
    response.headers['Content-Security-Policy'] = (
        "default-src 'self'; "
        # Allow fetch/XHR to any http(s) origin, expressed as scheme sources.
        "connect-src 'self' http: https:; "
        "img-src 'self' data:; "
        "style-src 'self' 'unsafe-inline'; "
        "script-src 'self' 'unsafe-inline'; "
    )
    return response

# Template rendered by the '/' route.
# NOTE(review): the template body is only a newline, so the page renders empty;
# the markup appears to be missing from this file — confirm.
index_html = """
"""

@app.route('/')
def index():
    """Render ``index_html`` as the landing page."""
    rendered = render_template_string(index_html)
    return rendered

@app.route('/generate', methods=['GET'])
def generate_image():
    """Handle ``GET /generate`` and return the generated image as PNG.

    Query parameters:

    * ``prompt``
    * ``negative_prompt``
    * ``steps``
    * ``cfgs`` (CFG scale)
    * ``sampler``
    * ``strength``
    * ``seed``
    * ``width``
    * ``height``
    * ``lora`` (Hub repo id)
    * ``lora_scale``

    Returns:
        An ``image/png`` response on success. Otherwise a JSON
        ``{"error": ...}`` body with status 400, either when a numeric
        parameter cannot be parsed or when generation reports an error.
    """
    args = request.args
    prompt = args.get("prompt", "")
    negative_prompt = args.get("negative_prompt", "")
    sampler = args.get("sampler", "DPM++ 2M Karras")
    lora = args.get("lora", None)  # optional LoRA repo id

    # Parse numeric parameters. A malformed value is a client error (400), not an
    # uncaught ValueError that Flask would turn into a 500.
    numeric_specs = (
        ("steps", 35, int),
        ("cfgs", 7, float),  # CFG scale
        ("strength", 0.7, float),
        ("seed", -1, int),
        ("width", 1024, int),
        ("height", 1024, int),
        ("lora_scale", 0.125, float),
    )
    parsed = {}
    for name, default, cast in numeric_specs:
        raw = args.get(name, default)
        try:
            parsed[name] = cast(raw)
        except (TypeError, ValueError):
            return jsonify({"error": f"Invalid value for '{name}': {raw!r}"}), 400

    image, error = query(
        prompt, negative_prompt, parsed["steps"], parsed["cfgs"], sampler,
        parsed["seed"], parsed["strength"], parsed["width"], parsed["height"],
        lora, parsed["lora_scale"],
    )

    if error:
        return jsonify({"error": error}), 400

    # Serialize the image in memory and stream it back as PNG.
    img_bytes = io.BytesIO()
    image.save(img_bytes, format='PNG')
    img_bytes.seek(0)
    return send_file(img_bytes, mimetype='image/png')

if __name__ == "__main__":
    # Start Flask's built-in server on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')