Kunpeng Song committed on
Commit
ef3a17c
1 Parent(s): 338f71e
app.py CHANGED
@@ -15,14 +15,14 @@ title = "MoMA"
 description = "This model has to run on GPU"
 article = "<p style='text-align: center'><a href='https://news.machinelearning.sg/posts/beautiful_profile_pics_remove_background_image_with_deeplabv3/'>Blog</a> | <a href='https://github.com/eugenesiow/practical-ml'>Github Repo</a></p>"
 
-def MoMA_demo(rgb, mask, subject, prompt):
+def MoMA_demo(rgb, subject, prompt):
     # move the input and model to GPU for speed if available
     with torch.no_grad():
-        generated_image = model.generate_images(rgb, mask, subject, prompt, strength=1.0, seed=2)
+        generated_image = model.generate_images(rgb, subject, prompt, strength=1.0, seed=2)
     return generated_image
 
-def inference(rgb, mask, subject, prompt):
-    result = MoMA_demo(rgb, mask, subject, prompt)
+def inference(rgb, subject, prompt):
+    result = MoMA_demo(rgb, subject, prompt)
     return result
 
 seed_everything(0)
@@ -40,13 +40,12 @@ model = MoMA_main_modal(args).to(args.device, dtype=torch.float16)
 gr.Interface(
     inference,
     [gr.Image(type="pil", label="Input RGB"),
-     gr.Image(type="pil", label="Input Mask"),
      gr.Textbox(lines=1, label="subject"),
     gr.Textbox(lines=5, label="Prompt")],
     gr.Image(type="pil", label="Output"),
     title=title,
     description=description,
     article=article,
-    examples=[["example_images/newImages/3.jpg", 'example_images/newImages/3_mask.jpg', 'car', 'A car in autumn with falling leaves.']],
+    examples=[["example_images/newImages/3.jpg", 'car', 'A car in autumn with falling leaves.']],
     # enable_queue=True
 ).launch(debug=False)
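
The net effect in app.py: the Gradio interface drops the "Input Mask" image field, and both MoMA_demo and inference lose their mask parameter, so generation is driven by the RGB image, subject, and prompt alone. A minimal sketch of the updated call path, assuming `model` is the MoMA_main_modal instance loaded above (the example file and prompt are the ones shipped with the demo):

import torch
from PIL import Image

rgb = Image.open("example_images/newImages/3.jpg")
with torch.no_grad():
    # mask argument is gone; the mask is now derived internally via rembg
    generated = model.generate_images(rgb, "car", "A car in autumn with falling leaves.",
                                      strength=1.0, seed=2)
# `generated` feeds gr.Image(type="pil"), so it is assumed to be PIL-compatible
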
app_version1.py ADDED
@@ -0,0 +1,52 @@
+import gradio as gr
+import cv2
+import torch
+import numpy as np
+from torchvision import transforms
+import torch
+from pytorch_lightning import seed_everything
+from torchvision.utils import save_image
+from model_lib.modules import MoMA_main_modal
+from model_lib.utils import parse_args
+import os
+os.environ["CUDA_VISIBLE_DEVICES"]="0"
+
+title = "MoMA"
+description = "This model has to run on GPU"
+article = "<p style='text-align: center'><a href='https://news.machinelearning.sg/posts/beautiful_profile_pics_remove_background_image_with_deeplabv3/'>Blog</a> | <a href='https://github.com/eugenesiow/practical-ml'>Github Repo</a></p>"
+
+def MoMA_demo(rgb, mask, subject, prompt):
+    # move the input and model to GPU for speed if available
+    with torch.no_grad():
+        generated_image = model.generate_images(rgb, mask, subject, prompt, strength=1.0, seed=2)
+    return generated_image
+
+def inference(rgb, mask, subject, prompt):
+    result = MoMA_demo(rgb, mask, subject, prompt)
+    return result
+
+seed_everything(0)
+args = parse_args()
+# load MoMA from HuggingFace. Auto download
+model = MoMA_main_modal(args).to(args.device, dtype=torch.float16)
+
+
+################ change texture ##################
+# prompt = "A wooden sculpture of a car on the table."
+# generated_image = model.generate_images(rgb_path, mask_path, subject, prompt, strength=0.4, seed=4, return_mask=True)  # set strength to 0.4 for better prompt fidelity
+# save_image(generated_image, f"{args.output_path}/{subject}_{prompt}.jpg")
+
+
+gr.Interface(
+    inference,
+    [gr.Image(type="pil", label="Input RGB"),
+     gr.Image(type="pil", label="Input Mask"),
+     gr.Textbox(lines=1, label="subject"),
+     gr.Textbox(lines=5, label="Prompt")],
+    gr.Image(type="pil", label="Output"),
+    title=title,
+    description=description,
+    article=article,
+    examples=[["example_images/newImages/3.jpg", 'example_images/newImages/3_mask.jpg', 'car', 'A car in autumn with falling leaves.']],
+    # enable_queue=True
+).launch(debug=False)
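
app_version1.py preserves the previous mask-based app verbatim. Its commented-out "change texture" block still uses the old signature; adapted to the new mask-free API it would presumably look like the sketch below (hedged: `return_mask=True` from the old comment is omitted, since this commit does not show whether generate_images still accepts it):

from PIL import Image
from torchvision.utils import save_image

prompt = "A wooden sculpture of a car on the table."
rgb = Image.open("example_images/newImages/3.jpg")
# strength=0.4 favors prompt fidelity over subject fidelity, per the original comment
generated_image = model.generate_images(rgb, "car", prompt, strength=0.4, seed=4)
save_image(generated_image, f"{args.output_path}/car_{prompt}.jpg")  # assumes a tensor return, as the old example did
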
dataset_lib/dataset_eval_MoMA.py CHANGED
@@ -3,9 +3,14 @@ import numpy as np
 import torch
 from torchvision import transforms
 from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+from rembg import remove
 
+def create_binary_mask(image):
+    grayscale = image.convert("L")
+    mask = grayscale.point(lambda x: 255 if x > 1 else 0, '1')
+    return mask
 
-def Dataset_evaluate_MoMA(rgb_path, prompt, subject, mask_path, moMA_main_modal):
+def Dataset_evaluate_MoMA(image_pil, prompt, subject, moMA_main_modal):
 
     LLaVa_processor = moMA_main_modal.image_processor_llava
     llava_config = moMA_main_modal.model_llava.config
@@ -14,9 +19,7 @@ def Dataset_evaluate_MoMA(rgb_path, prompt, subject, mask_path, moMA_main_modal):
         transforms.Resize((512, 512)),
     ])
 
-    rgb_path, prompt, mask_path = rgb_path, prompt, mask_path
-    image_pil = rgb_path  # Image.open(rgb_path)
-    mask_pil = mask_path  # Image.open(mask_path)
+    mask_pil = create_binary_mask(remove(image_pil))
     blip2_opt = prompt
 
     if transform is not None:
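
The user-supplied mask is gone because Dataset_evaluate_MoMA now derives it: rembg cuts the subject out (background pixels become transparent black), and create_binary_mask thresholds the grayscale cutout into a 1-bit foreground mask, any value above 1 counting as subject. A standalone sketch of that step (the output filename is illustrative only):

from PIL import Image
from rembg import remove  # pip install rembg

def create_binary_mask(image):
    grayscale = image.convert("L")                               # collapse the RGBA cutout to one channel
    return grayscale.point(lambda x: 255 if x > 1 else 0, '1')   # background -> 0, subject -> 1

rgb = Image.open("example_images/newImages/3.jpg")
mask = create_binary_mask(remove(rgb))   # replaces the deleted hand-made 3_mask.jpg
mask.save("3_mask_auto.png")             # hypothetical path, for inspection
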
example_images/newImages/3_mask.jpg DELETED
Binary file (7.31 kB)
 
model_lib/modules.py CHANGED
@@ -136,8 +136,8 @@ class MoMA_main_modal(nn.Module):
     def reset(self):
         self.moMA_generator.reset_all()
 
-    def generate_images(self, rgb_path, mask_path, subject, prompt, strength=1.0, num=1, seed=0):
-        batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject, mask_path, self)
+    def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
+        batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject, self)
         self.moMA_generator.set_selfAttn_strength(strength)
 
         with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
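
For orientation, the mask-free call chain after this commit; note that despite its name, the rgb_path parameter of generate_images now receives the PIL image handed over by Gradio, which Dataset_evaluate_MoMA binds as image_pil:

# Call chain as of this commit (names taken from the diffs above):
#   inference(rgb, subject, prompt)                    # app.py, Gradio callback
#     -> MoMA_demo(rgb, subject, prompt)               # wraps the call in torch.no_grad()
#       -> model.generate_images(rgb, subject, prompt, strength=1.0, seed=2)
#         -> Dataset_evaluate_MoMA(rgb, prompt, subject, model)  # derives the mask via rembg
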