gokaygokay commited on
Commit
0b30521
1 Parent(s): 45523be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -2
app.py CHANGED
@@ -5,9 +5,43 @@ import os
5
  import re
6
  from datetime import datetime
7
  from huggingface_hub import InferenceClient
 
 
 
 
8
 
 
9
 
10
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Load JSON files
12
  def load_json_file(file_name):
13
  file_path = os.path.join("data", file_name)
@@ -102,7 +136,7 @@ class PromptGenerator:
102
 
103
  def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
104
  additional_details, photography_styles, device, photographer, artist, digital_artform,
105
- place, lighting, clothing, composition, pose, background):
106
  kwargs = locals()
107
  del kwargs['self']
108
 
@@ -254,6 +288,10 @@ class PromptGenerator:
254
  components.append(f"by {self.get_choice(kwargs.get('artist', ''), ARTIST)}")
255
  components.append("BREAK_CLIPL")
256
 
 
 
 
 
257
  prompt = " ".join(components)
258
  prompt = re.sub(" +", " ", prompt)
259
  replaced = prompt.replace("of as", "of")
@@ -367,6 +405,7 @@ def create_interface():
367
  pose = gr.Dropdown(["disabled", "random"] + POSE, label="Pose", value="random")
368
  background = gr.Dropdown(["disabled", "random"] + BACKGROUND, label="Background", value="random")
369
  with gr.Column():
 
370
  generate_button = gr.Button("Generate Prompt")
371
  output = gr.Textbox(label="Generated Prompt / Input Text", lines=5)
372
  t5xxl_output = gr.Textbox(label="T5XXL Output", visible=True)
@@ -389,7 +428,7 @@ def create_interface():
389
  prompt_generator.generate_prompt,
390
  inputs=[seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
391
  additional_details, photography_styles, device, photographer, artist, digital_artform,
392
- place, lighting, clothing, composition, pose, background],
393
  outputs=[output, gr.Number(visible=False), t5xxl_output, clip_l_output, clip_g_output]
394
  )
395
 
 
import re
from datetime import datetime
from huggingface_hub import InferenceClient
import subprocess
import sys
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Install flash-attn at startup (Spaces images don't ship a prebuilt wheel).
# Fixes vs. the original call:
#  - argv list + shell=False instead of a shell string (no shell parsing).
#  - sys.executable -m pip so the install targets *this* interpreter.
#  - merge os.environ instead of replacing it: passing env={...} wholesale
#    would drop PATH and every other variable the build needs; the skip-build
#    flag must be added on top of the existing environment.
#  - check=False keeps the app booting even if the wheel build fails
#    (flash-attn is an optional speedup, not a hard requirement here).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")


# Initialize the Florence-2 captioning model once at import time.
# NOTE(review): this downloads and loads the model on module import, which
# slows startup — acceptable for a Space, but confirm if reused as a library.
device = "cuda" if torch.cuda.is_available() else "cpu"
florence_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Florence-2-base', trust_remote_code=True
).to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
    'microsoft/Florence-2-base', trust_remote_code=True
)
22
+
# Florence caption helper
def florence_caption(image):
    """Return a detailed caption for *image* from the Florence-2 model.

    Accepts a PIL image or anything ``Image.fromarray`` understands
    (e.g. the numpy array Gradio hands over) and runs the
    "<MORE_DETAILED_CAPTION>" task, returning its parsed text output.
    """
    task_prompt = "<MORE_DETAILED_CAPTION>"

    # Gradio may supply a numpy array; the processor expects a PIL image.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    model_inputs = florence_processor(
        text=task_prompt, images=pil_image, return_tensors="pt"
    ).to(device)

    output_ids = florence_model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Keep special tokens: post_process_generation parses them to locate
    # the task's answer span.
    raw_text = florence_processor.batch_decode(output_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        raw_text,
        task=task_prompt,
        image_size=(pil_image.width, pil_image.height),
    )
    return parsed[task_prompt]
44
+
45
  # Load JSON files
46
  def load_json_file(file_name):
47
  file_path = os.path.join("data", file_name)
 
136
 
137
  def generate_prompt(self, seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
138
  additional_details, photography_styles, device, photographer, artist, digital_artform,
139
+ place, lighting, clothing, composition, pose, background, input_image):
140
  kwargs = locals()
141
  del kwargs['self']
142
 
 
288
  components.append(f"by {self.get_choice(kwargs.get('artist', ''), ARTIST)}")
289
  components.append("BREAK_CLIPL")
290
 
291
+ if input_image is not None:
292
+ caption = florence_caption(input_image)
293
+ components.append(f" {caption}")
294
+
295
  prompt = " ".join(components)
296
  prompt = re.sub(" +", " ", prompt)
297
  replaced = prompt.replace("of as", "of")
 
405
  pose = gr.Dropdown(["disabled", "random"] + POSE, label="Pose", value="random")
406
  background = gr.Dropdown(["disabled", "random"] + BACKGROUND, label="Background", value="random")
407
  with gr.Column():
408
+ input_image = gr.Image(label="Input Image (optional)")
409
  generate_button = gr.Button("Generate Prompt")
410
  output = gr.Textbox(label="Generated Prompt / Input Text", lines=5)
411
  t5xxl_output = gr.Textbox(label="T5XXL Output", visible=True)
 
428
  prompt_generator.generate_prompt,
429
  inputs=[seed, custom, subject, artform, photo_type, body_types, default_tags, roles, hairstyles,
430
  additional_details, photography_styles, device, photographer, artist, digital_artform,
431
+ place, lighting, clothing, composition, pose, background, input_image],
432
  outputs=[output, gr.Number(visible=False), t5xxl_output, clip_l_output, clip_g_output]
433
  )
434