Kunpeng Song commited on
Commit
a6600be
1 Parent(s): 5b2734d
Files changed (3) hide show
  1. .DS_Store +0 -0
  2. app.py +6 -2
  3. dataset_lib/dataset_eval_MoMA.py +2 -2
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import spaces
2
-
3
  import gradio as gr
4
  import torch
5
  import numpy as np
6
  import torch
7
  from pytorch_lightning import seed_everything
8
  from model_lib.utils import parse_args
9
- import os
 
10
  os.environ["CUDA_VISIBLE_DEVICES"]="0"
11
 
12
  title = "MoMA"
@@ -17,6 +18,9 @@ args = parse_args()
17
 
18
  model = None
19
 
 
 
 
20
  @spaces.GPU
21
  def inference(rgb, subject, prompt, strength, seed):
22
  seed = int(seed) if seed else 0
 
1
  import spaces
2
+ import os
3
  import gradio as gr
4
  import torch
5
  import numpy as np
6
  import torch
7
  from pytorch_lightning import seed_everything
8
  from model_lib.utils import parse_args
9
+ from llava.mm_utils import process_image
10
+
11
  os.environ["CUDA_VISIBLE_DEVICES"]="0"
12
 
13
  title = "MoMA"
 
18
 
19
  model = None
20
 
21
+ def my_process_image(a, b, c):
22
+ return process_image(a, b, c)
23
+
24
  @spaces.GPU
25
  def inference(rgb, subject, prompt, strength, seed):
26
  seed = int(seed) if seed else 0
dataset_lib/dataset_eval_MoMA.py CHANGED
@@ -2,7 +2,7 @@ from PIL import Image
2
  import numpy as np
3
  import torch
4
  from torchvision import transforms
5
- from llava.mm_utils import process_images
6
  from rembg import remove
7
 
8
  def create_binary_mask(image):
@@ -38,7 +38,7 @@ def Dataset_evaluate_MoMA(image_pil, prompt,subject, moMA_main_modal):
38
  image_wb = image * mask + torch.ones_like(image)* (1-mask)*255
39
  image_pil = Image.fromarray(image_wb.permute(1,2,0).numpy().astype(np.uint8))
40
 
41
- res['llava_processed'] = process_images([image_pil], LLaVa_processor, llava_config)
42
  res['label'] = [subject]
43
  return res
44
 
 
2
  import numpy as np
3
  import torch
4
  from torchvision import transforms
5
+ from ..app import my_process_image
6
  from rembg import remove
7
 
8
  def create_binary_mask(image):
 
38
  image_wb = image * mask + torch.ones_like(image)* (1-mask)*255
39
  image_pil = Image.fromarray(image_wb.permute(1,2,0).numpy().astype(np.uint8))
40
 
41
+ res['llava_processed'] = my_process_image([image_pil], LLaVa_processor, llava_config)
42
  res['label'] = [subject]
43
  return res
44