gokaygokay
commited on
Commit
•
0b30521
1
Parent(s):
45523be
Update app.py
Browse files
app.py
CHANGED
@@ -5,9 +5,43 @@ import os
|
|
5 |
import re
|
6 |
from datetime import datetime
|
7 |
from huggingface_hub import InferenceClient
|
|
|
|
|
|
|
|
|
8 |
|
|
|
9 |
|
10 |
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# Load JSON files
|
12 |
def load_json_file(file_name):
|
13 |
file_path = os.path.join("data", file_name)
|
@@ -102,7 +136,7 @@ class PromptGenerator:
|
|
102 |
|
103 |
def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
|
104 |
additional_details, photography_styles, device, photographer, artist, digital_artform,
|
105 |
-
place, lighting, clothing, composition, pose, background):
|
106 |
kwargs = locals()
|
107 |
del kwargs['self']
|
108 |
|
@@ -254,6 +288,10 @@ class PromptGenerator:
|
|
254 |
components.append(f"by {self.get_choice(kwargs.get('artist', ''), ARTIST)}")
|
255 |
components.append("BREAK_CLIPL")
|
256 |
|
|
|
|
|
|
|
|
|
257 |
prompt = " ".join(components)
|
258 |
prompt = re.sub(" +", " ", prompt)
|
259 |
replaced = prompt.replace("of as", "of")
|
@@ -367,6 +405,7 @@ def create_interface():
|
|
367 |
pose = gr.Dropdown(["disabled", "random"] + POSE, label="Pose", value="random")
|
368 |
background = gr.Dropdown(["disabled", "random"] + BACKGROUND, label="Background", value="random")
|
369 |
with gr.Column():
|
|
|
370 |
generate_button = gr.Button("Generate Prompt")
|
371 |
output = gr.Textbox(label="Generated Prompt / Input Text", lines=5)
|
372 |
t5xxl_output = gr.Textbox(label="T5XXL Output", visible=True)
|
@@ -389,7 +428,7 @@ def create_interface():
|
|
389 |
prompt_generator.generate_prompt,
|
390 |
inputs=[seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
|
391 |
additional_details, photography_styles, device, photographer, artist, digital_artform,
|
392 |
-
place, lighting, clothing, composition, pose, background],
|
393 |
outputs=[output, gr.Number(visible=False), t5xxl_output, clip_l_output, clip_g_output]
|
394 |
)
|
395 |
|
|
|
5 |
import re
|
6 |
from datetime import datetime
|
7 |
from huggingface_hub import InferenceClient
|
8 |
+
import subprocess
|
9 |
+
import torch
|
10 |
+
from PIL import Image
|
11 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
12 |
|
13 |
+
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
14 |
|
15 |
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
|
16 |
+
|
17 |
+
|
18 |
+
# Initialize Florence model
|
19 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
+
florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
|
21 |
+
florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
|
22 |
+
|
23 |
+
# Florence caption function
|
24 |
+
def florence_caption(image):
|
25 |
+
if not isinstance(image, Image.Image):
|
26 |
+
image = Image.fromarray(image)
|
27 |
+
|
28 |
+
inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
|
29 |
+
generated_ids = florence_model.generate(
|
30 |
+
input_ids=inputs["input_ids"],
|
31 |
+
pixel_values=inputs["pixel_values"],
|
32 |
+
max_new_tokens=1024,
|
33 |
+
early_stopping=False,
|
34 |
+
do_sample=False,
|
35 |
+
num_beams=3,
|
36 |
+
)
|
37 |
+
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
38 |
+
parsed_answer = florence_processor.post_process_generation(
|
39 |
+
generated_text,
|
40 |
+
task="<MORE_DETAILED_CAPTION>",
|
41 |
+
image_size=(image.width, image.height)
|
42 |
+
)
|
43 |
+
return parsed_answer["<MORE_DETAILED_CAPTION>"]
|
44 |
+
|
45 |
# Load JSON files
|
46 |
def load_json_file(file_name):
|
47 |
file_path = os.path.join("data", file_name)
|
|
|
136 |
|
137 |
def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
|
138 |
additional_details, photography_styles, device, photographer, artist, digital_artform,
|
139 |
+
place, lighting, clothing, composition, pose, background, input_image):
|
140 |
kwargs = locals()
|
141 |
del kwargs['self']
|
142 |
|
|
|
288 |
components.append(f"by {self.get_choice(kwargs.get('artist', ''), ARTIST)}")
|
289 |
components.append("BREAK_CLIPL")
|
290 |
|
291 |
+
if input_image is not None:
|
292 |
+
caption = florence_caption(input_image)
|
293 |
+
components.append(f" {caption}")
|
294 |
+
|
295 |
prompt = " ".join(components)
|
296 |
prompt = re.sub(" +", " ", prompt)
|
297 |
replaced = prompt.replace("of as", "of")
|
|
|
405 |
pose = gr.Dropdown(["disabled", "random"] + POSE, label="Pose", value="random")
|
406 |
background = gr.Dropdown(["disabled", "random"] + BACKGROUND, label="Background", value="random")
|
407 |
with gr.Column():
|
408 |
+
input_image = gr.Image(label="Input Image (optional)")
|
409 |
generate_button = gr.Button("Generate Prompt")
|
410 |
output = gr.Textbox(label="Generated Prompt / Input Text", lines=5)
|
411 |
t5xxl_output = gr.Textbox(label="T5XXL Output", visible=True)
|
|
|
428 |
prompt_generator.generate_prompt,
|
429 |
inputs=[seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
|
430 |
additional_details, photography_styles, device, photographer, artist, digital_artform,
|
431 |
+
place, lighting, clothing, composition, pose, background, input_image],
|
432 |
outputs=[output, gr.Number(visible=False), t5xxl_output, clip_l_output, clip_g_output]
|
433 |
)
|
434 |
|