jiuface commited on
Commit
a4f92f5
1 Parent(s): 7b934fb
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Tuple, Optional
2
-
3
  import gradio as gr
4
  import numpy as np
5
  import random
@@ -9,6 +9,7 @@ from diffusers import FluxInpaintPipeline
9
  import torch
10
  from PIL import Image, ImageFilter
11
  from huggingface_hub import login
 
12
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
13
  import copy
14
  import random
@@ -38,9 +39,6 @@ dtype = torch.bfloat16
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
39
  base_model = "black-forest-labs/FLUX.1-dev"
40
 
41
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
42
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
43
- pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
44
 
45
  FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=device)
46
  SAM_IMAGE_MODEL = load_sam_image_model(device=device)
@@ -133,7 +131,7 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
133
  print("upload finish", image_file)
134
  return image_file
135
 
136
-
137
  def run_flux(
138
  image: Image.Image,
139
  mask: Image.Image,
@@ -149,8 +147,13 @@ def run_flux(
149
  ) -> Image.Image:
150
  print("Running FLUX...")
151
 
 
 
 
 
152
  with calculateDuration("load lora"):
153
  print("start to load lora", lora_path, lora_weights)
 
154
  pipe.load_lora_weights(lora_path, weight_name=lora_weights)
155
 
156
  width, height = resolution_wh
@@ -159,7 +162,7 @@ def run_flux(
159
  generator = torch.Generator().manual_seed(seed_slicer)
160
 
161
  with calculateDuration("run pipe"):
162
- genearte_image = PIPE(
163
  prompt=prompt,
164
  image=image,
165
  mask_image=mask,
@@ -170,12 +173,13 @@ def run_flux(
170
  num_inference_steps=num_inference_steps_slider,
171
  max_sequence_length=256,
172
  joint_attention_kwargs={"scale": lora_scale},
 
173
  ).images[0]
174
 
175
  return genearte_image
176
 
177
-
178
- def genearte_mask(image: Image.Image, masking_prompt_text: str) -> Image.Image:
179
  # generate mask by florence & sam
180
  print("Generating mask...")
181
  task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
@@ -186,7 +190,7 @@ def genearte_mask(image: Image.Image, masking_prompt_text: str) -> Image.Image:
186
  model=FLORENCE_MODEL,
187
  processor=FLORENCE_PROCESSOR,
188
  device=device,
189
- image=image,
190
  task=task_prompt,
191
  text=masking_prompt_text
192
  )
@@ -203,7 +207,7 @@ def genearte_mask(image: Image.Image, masking_prompt_text: str) -> Image.Image:
203
 
204
  with calculateDuration("generate segmenet mask"):
205
  # using sam generate segments images
206
- detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)
207
  if len(detections) == 0:
208
  gr.Info("No objects detected.")
209
  return None
@@ -225,7 +229,7 @@ def genearte_mask(image: Image.Image, masking_prompt_text: str) -> Image.Image:
225
  return images[0]
226
 
227
 
228
- @spaces.GPU(duration=120)
229
  def process(
230
  image_url: str,
231
  inpainting_prompt_text: str,
 
1
  from typing import Tuple, Optional
2
+ import os
3
  import gradio as gr
4
  import numpy as np
5
  import random
 
9
  import torch
10
  from PIL import Image, ImageFilter
11
  from huggingface_hub import login
12
+ from diffusers import AutoencoderTiny, AutoencoderKL
13
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
14
  import copy
15
  import random
 
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
  base_model = "black-forest-labs/FLUX.1-dev"
41
 
 
 
 
42
 
43
  FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=device)
44
  SAM_IMAGE_MODEL = load_sam_image_model(device=device)
 
131
  print("upload finish", image_file)
132
  return image_file
133
 
134
+ @spaces.GPU(duration=60)
135
  def run_flux(
136
  image: Image.Image,
137
  mask: Image.Image,
 
147
  ) -> Image.Image:
148
  print("Running FLUX...")
149
 
150
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
151
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
152
+ pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
153
+
154
  with calculateDuration("load lora"):
155
  print("start to load lora", lora_path, lora_weights)
156
+ pipe.unload_lora_weights()
157
  pipe.load_lora_weights(lora_path, weight_name=lora_weights)
158
 
159
  width, height = resolution_wh
 
162
  generator = torch.Generator().manual_seed(seed_slicer)
163
 
164
  with calculateDuration("run pipe"):
165
+ genearte_image = pipe(
166
  prompt=prompt,
167
  image=image,
168
  mask_image=mask,
 
173
  num_inference_steps=num_inference_steps_slider,
174
  max_sequence_length=256,
175
  joint_attention_kwargs={"scale": lora_scale},
176
+ good_vae=good_vae
177
  ).images[0]
178
 
179
  return genearte_image
180
 
181
+ @spaces.GPU(duration=10)
182
+ def genearte_mask(image_input: Image.Image, masking_prompt_text: str) -> Image.Image:
183
  # generate mask by florence & sam
184
  print("Generating mask...")
185
  task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
 
190
  model=FLORENCE_MODEL,
191
  processor=FLORENCE_PROCESSOR,
192
  device=device,
193
+ image=image_input,
194
  task=task_prompt,
195
  text=masking_prompt_text
196
  )
 
207
 
208
  with calculateDuration("generate segmenet mask"):
209
  # using sam generate segments images
210
+ detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
211
  if len(detections) == 0:
212
  gr.Info("No objects detected.")
213
  return None
 
229
  return images[0]
230
 
231
 
232
+
233
  def process(
234
  image_url: str,
235
  inpainting_prompt_text: str,
requirements.txt CHANGED
@@ -14,4 +14,5 @@ opencv-python
14
  pytest
15
  requests
16
  git+https://github.com/Gothos/diffusers.git@flux-inpaint
17
- boto3
 
 
14
  pytest
15
  requests
16
  git+https://github.com/Gothos/diffusers.git@flux-inpaint
17
+ boto3
18
+ sentencepiece
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (125 Bytes). View file
 
utils/__pycache__/florence.cpython-310.pyc ADDED
Binary file (2.31 kB). View file
 
utils/__pycache__/sam.cpython-310.pyc ADDED
Binary file (1.57 kB). View file
 
utils/florence.py CHANGED
@@ -29,10 +29,8 @@ def load_florence_model(
29
  device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
30
  ) -> Tuple[Any, Any]:
31
  with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
32
- model = AutoModelForCausalLM.from_pretrained(
33
- checkpoint, trust_remote_code=True).to(device).eval()
34
- processor = AutoProcessor.from_pretrained(
35
- checkpoint, trust_remote_code=True)
36
  return model, processor
37
 
38
 
@@ -49,16 +47,8 @@ def run_florence_inference(
49
  else:
50
  prompt = task
51
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
52
- print(inputs)
53
- generated_ids = model.generate(
54
- input_ids=inputs["input_ids"],
55
- pixel_values=inputs["pixel_values"],
56
- max_new_tokens=1024,
57
- num_beams=3
58
- )
59
- generated_text = processor.batch_decode(
60
- generated_ids, skip_special_tokens=False)[0]
61
- response = processor.post_process_generation(
62
- generated_text, task=task, image_size=image.size)
63
- print(generated_text, response)
64
- return generated_text, response
 
29
  device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
30
  ) -> Tuple[Any, Any]:
31
  with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
32
+ model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).to(device).eval()
33
+ processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
 
 
34
  return model, processor
35
 
36
 
 
47
  else:
48
  prompt = task
49
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
50
+ generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3)
51
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
52
+ response = processor.post_process_generation(generated_text, task=task, image_size=image.size)
53
+ print("run_florence_inference", "finish", generated_text, response)
54
+ return generated_text, response