Spaces: Paused
Kunpeng Song committed 8a4a948 (1 parent: 6359e9a)
recreate
Files changed:
- README.md +6 -4
- app.py +52 -0
- checkpoints/.DS_Store +0 -0
- checkpoints/ckpt_saving_path.txt +0 -0
- dataset_lib/__pycache__/dataset_eval_MoMA.cpython-310.pyc +0 -0
- dataset_lib/dataset_eval_MoMA.py +41 -0
- example_images/newImages/.DS_Store +0 -0
- example_images/newImages/3.jpg +0 -0
- example_images/newImages/3_mask.jpg +0 -0
- model_lib/__init__.py +0 -0
- model_lib/__pycache__/__init__.cpython-310.pyc +0 -0
- model_lib/__pycache__/__init__.cpython-39.pyc +0 -0
- model_lib/__pycache__/attention_processor.cpython-310.pyc +0 -0
- model_lib/__pycache__/moMA_generator.cpython-310.pyc +0 -0
- model_lib/__pycache__/moMA_generator.cpython-39.pyc +0 -0
- model_lib/__pycache__/modules.cpython-310.pyc +0 -0
- model_lib/__pycache__/modules.cpython-39.pyc +0 -0
- model_lib/__pycache__/utils.cpython-310.pyc +0 -0
- model_lib/attention_processor.py +245 -0
- model_lib/moMA_generator.py +285 -0
- model_lib/modules.py +151 -0
- model_lib/utils.py +27 -0
- output/car_A car in autumn with falling leaves..jpg +0 -0
- output/car_A wooden sculpture of a car on the table..jpg +0 -0
- requirements.txt +32 -0
README.md
CHANGED
@@ -1,12 +1,14 @@
 ---
-title: MoMA
+title: MoMA
 emoji: π
-colorFrom:
-colorTo:
+colorFrom: yellow
+colorTo: green
 sdk: gradio
 sdk_version: 4.31.4
 app_file: app.py
 pinned: false
+license: apache-2.0
+short_description: Multi-modal LLM for image personalization
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,52 @@
import gradio as gr
import cv2
import torch
import numpy as np
from torchvision import transforms
from pytorch_lightning import seed_everything
from torchvision.utils import save_image
from model_lib.modules import MoMA_main_modal
from model_lib.utils import parse_args
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

title = "MoMA"
description = "This model has to run on a GPU"
article = "<p style='text-align: center'><a href='https://news.machinelearning.sg/posts/beautiful_profile_pics_remove_background_image_with_deeplabv3/'>Blog</a> | <a href='https://github.com/eugenesiow/practical-ml'>Github Repo</a></p>"

def MoMA_demo(rgb, mask, subject, prompt):
    # run generation without tracking gradients
    with torch.no_grad():
        generated_image = model.generate_images(rgb, mask, subject, prompt, strength=1.0, seed=2)
    return generated_image

def inference(rgb, mask, subject, prompt):
    result = MoMA_demo(rgb, mask, subject, prompt)
    return result

seed_everything(0)
args = parse_args()
# load MoMA from HuggingFace. Auto download
model = MoMA_main_modal(args).to(args.device, dtype=torch.bfloat16)


################ change texture ##################
# prompt = "A wooden sculpture of a car on the table."
# generated_image = model.generate_images(rgb_path, mask_path, subject, prompt, strength=0.4, seed=4, return_mask=True) # set strength to 0.4 for better prompt fidelity
# save_image(generated_image,f"{args.output_path}/{subject}_{prompt}.jpg")


gr.Interface(
    inference,
    [gr.Image(type="pil", label="Input RGB"),
     gr.Image(type="pil", label="Input Mask"),
     gr.Textbox(lines=1, label="subject"),
     gr.Textbox(lines=5, label="Prompt")],
    gr.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[["example_images/newImages/3.jpg", 'example_images/newImages/3_mask.jpg', 'car', 'A car in autumn with falling leaves.']],
    # enable_queue=True
).launch(debug=False)
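Not part of the commit: a minimal local sanity check that calls the same inference() function the Gradio UI wires up, using the example assets added in this commit. It assumes app.py has been imported (so the checkpoints have downloaded and the model is on a CUDA GPU); the output filename is illustrative.

# Minimal sketch, assuming app.py has loaded the model successfully.
from PIL import Image

rgb = Image.open("example_images/newImages/3.jpg")
mask = Image.open("example_images/newImages/3_mask.jpg")
result = inference(rgb, mask, "car", "A car in autumn with falling leaves.")
result.save("output/gradio_smoke_test.jpg")  # inference() returns a PIL image; filename is a placeholder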
checkpoints/.DS_Store
ADDED
Binary file (6.15 kB)

checkpoints/ckpt_saving_path.txt
ADDED
File without changes

dataset_lib/__pycache__/dataset_eval_MoMA.cpython-310.pyc
ADDED
Binary file (1.43 kB)
dataset_lib/dataset_eval_MoMA.py
ADDED
@@ -0,0 +1,41 @@
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria


def Dataset_evaluate_MoMA(rgb_path, prompt, subject, mask_path, moMA_main_modal):

    LLaVa_processor = moMA_main_modal.image_processor_llava
    llava_config = moMA_main_modal.model_llava.config

    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        ])

    rgb_path, prompt, mask_path = rgb_path, prompt, mask_path
    image_pil = rgb_path  # Image.open(rgb_path)
    mask_pil = mask_path  # Image.open(mask_path)
    blip2_opt = prompt

    if transform is not None:
        image_pil = transform(image_pil)
        mask_pil = transform(mask_pil)

    mask_pil = np.array(mask_pil)
    mask_pil = mask_pil[:,:,0] if len(mask_pil.shape)==3 else mask_pil
    image = torch.from_numpy(np.array(image_pil)).permute(2,0,1)
    mask = (torch.clamp((torch.from_numpy(mask_pil).unsqueeze(0)).float(),min=0.0,max=1.0)>0).float()

    res = {'image': (image/127.5-1).unsqueeze(0),
           'mask': mask.unsqueeze(0),
           'text': [blip2_opt]}

    image_wb = image * mask + torch.ones_like(image)* (1-mask)*255
    image_pil = Image.fromarray(image_wb.permute(1,2,0).numpy().astype(np.uint8))

    res['llava_processed'] = process_images([image_pil], LLaVa_processor, llava_config)
    res['label'] = [subject]
    return res
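Not part of the commit: a quick shape check of the tensor preparation above, on synthetic arrays, so it runs without LLaVA or the MoMA wrapper. Shapes mirror what Dataset_evaluate_MoMA produces for a 512x512 input.

# Minimal sketch of the normalization/masking math above, on random stand-in data.
import numpy as np
import torch

image_np = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # stand-in RGB image
mask_np = (np.random.rand(512, 512) > 0.5).astype(np.uint8) * 255    # stand-in binary mask

image = torch.from_numpy(image_np).permute(2, 0, 1)                   # [3, 512, 512]
mask = (torch.clamp(torch.from_numpy(mask_np).unsqueeze(0).float(), 0.0, 1.0) > 0).float()

print((image / 127.5 - 1).unsqueeze(0).shape)  # torch.Size([1, 3, 512, 512]), values in [-1, 1]
print(mask.unsqueeze(0).shape)                 # torch.Size([1, 1, 512, 512]), values in {0, 1}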
example_images/newImages/.DS_Store
ADDED
Binary file (6.15 kB)

example_images/newImages/3.jpg
ADDED

example_images/newImages/3_mask.jpg
ADDED

model_lib/__init__.py
ADDED
File without changes

model_lib/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (198 Bytes)

model_lib/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (196 Bytes)

model_lib/__pycache__/attention_processor.cpython-310.pyc
ADDED
Binary file (7.07 kB)

model_lib/__pycache__/moMA_generator.cpython-310.pyc
ADDED
Binary file (10.1 kB)

model_lib/__pycache__/moMA_generator.cpython-39.pyc
ADDED
Binary file (10 kB)

model_lib/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (6.98 kB)

model_lib/__pycache__/modules.cpython-39.pyc
ADDED
Binary file (7.5 kB)

model_lib/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.46 kB)
model_lib/attention_processor.py
ADDED
@@ -0,0 +1,245 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math
from torchvision.utils import save_image
import torchvision.transforms as T

def get_mask_from_cross(attn_processors):
    reference_masks = []
    for attn_processor in attn_processors.values():
        if isinstance(attn_processor, IPAttnProcessor):
            reference_masks.append(attn_processor.mask_i)
    mask = torch.cat(reference_masks,dim=1).mean(dim=1)
    mask = (mask-mask.min())/(mask.max()-mask.min())
    mask = (mask>0.2).to(torch.float32)*mask
    mask = (mask-mask.min())/(mask.max()-mask.min())
    return mask.unsqueeze(1)

class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.store_attn = None
        self.enabled = True
        self.mode = 'inject'

        self.subject_idxs = None
        self.mask_i = None
        self.mask_ig_prev = None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        if self.enabled:
            if self.mode == 'inject' or self.mode == 'masked_generation':
                ip_key = self.to_k_ip(ip_hidden_states.to(torch.float16))
                ip_value = self.to_v_ip(ip_hidden_states.to(torch.float16))
                ip_key = attn.head_to_batch_dim(ip_key)
                ip_value = attn.head_to_batch_dim(ip_value)
                ip_attention_probs = attn.get_attention_scores(query, ip_key.to(torch.float32), None)
                ip_hidden_states = torch.bmm(ip_attention_probs, ip_value.to(torch.float32))
                ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
                if (self.mask_ig_prev is not None) and self.mode == 'masked_generation':
                    mask_ig_prev = rearrange(F.interpolate(self.mask_ig_prev,size=int(math.sqrt(query.shape[1]))),"b c h w -> b (h w) c")
                    if not mask_ig_prev.shape[0]==ip_hidden_states.shape[0]: mask_ig_prev = mask_ig_prev.repeat(2,1,1)
                    ip_hidden_states = ip_hidden_states * mask_ig_prev
                hidden_states = hidden_states + self.scale * ip_hidden_states
            if self.mode == 'extract' or self.mode == 'masked_generation':
                subject_idxs = self.subject_idxs*2 if not (hidden_states.shape[0] == len(self.subject_idxs)) else self.subject_idxs
                assert (hidden_states.shape[0] == len(subject_idxs))
                attentions = rearrange(attention_probs, '(b h) n d -> b h n d', h=8).mean(1)
                attn_extracted = [attentions[i, :, subject_idxs[i]].sum(-1) for i in range(hidden_states.shape[0])]
                attn_extracted = [(atn-atn.min())/(atn.max()-atn.min()) for atn in attn_extracted]
                attn_extracted = torch.stack(attn_extracted, dim=0)
                attn_extracted = rearrange(attn_extracted, 'b (h w) -> b h w', h=int(math.sqrt(attention_probs.shape[1])))
                attn_extracted = torch.clamp(F.interpolate(attn_extracted.unsqueeze(1),size=512),min=0,max=1)
                self.mask_i = attn_extracted

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states

### added for self attention
class IPAttnProcessor_Self(nn.Module):
    r"""
    Attention processor for IP-Adapter (but for self attention).
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(hidden_size, hidden_size, bias=False)

        self.scale_learnable = torch.nn.Parameter(torch.zeros(1),requires_grad=True)

        self.enabled = True
        self.mode = 'extract'

        self.store_ks, self.store_vs = [], []
        self.mask_id, self.mask_ig = None, None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_0 = attn.to_k(encoder_hidden_states)
        value_0 = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key_0)
        value = attn.head_to_batch_dim(value_0)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if self.enabled:
            if self.mode == 'extract':
                ks, vs = attn.head_to_batch_dim(self.to_k_ip(key_0)), attn.head_to_batch_dim(self.to_v_ip(value_0))
                self.store_ks, self.store_vs = self.store_ks+[ks], self.store_vs+[vs]
                self.store_ks, self.store_vs = torch.cat(self.store_ks,dim=0), torch.cat(self.store_vs,dim=0)

            if self.mode == 'masked_generation':
                if not self.store_ks.shape[0]==query.shape[0]: self.store_ks,self.store_vs = self.store_ks.repeat(2,1,1), self.store_vs.repeat(2,1,1)
                mask_id = self.mask_id.clone()
                mask_id.masked_fill_(self.mask_id==False, -torch.finfo(mask_id.dtype).max)
                mask_id = rearrange(F.interpolate(mask_id,size=int(math.sqrt(query.shape[1]))),"b c h w -> b c (h w)").repeat(1,query.shape[1],1)
                mask_id = mask_id.repeat(8,1,1) # 8 is head dim
                if not mask_id.shape[0]==int(query.shape[0]): mask_id = mask_id.repeat(2,1,1)
                attention_probs_ref = attn.get_attention_scores(query, self.store_ks, mask_id.to(query.dtype))
                hidden_states_ref = torch.bmm(attention_probs_ref, self.store_vs)
                hidden_states_ref = attn.batch_to_head_dim(hidden_states_ref)
                scale = self.scale.repeat(int(batch_size/self.scale.shape[0])).unsqueeze(-1).unsqueeze(-1) if type(self.scale)==torch.Tensor else self.scale
                if self.mask_ig == None:
                    hidden_states = hidden_states + scale * hidden_states_ref * self.scale_learnable
                else:
                    mask_ig = rearrange(F.interpolate(self.mask_ig,size=int(math.sqrt(query.shape[1]))),"b c h w -> b (h w) c")
                    if not mask_ig.shape[0]==hidden_states_ref.shape[0]: mask_ig = mask_ig.repeat(2,1,1)
                    hidden_states = hidden_states + scale * hidden_states_ref * mask_ig * self.scale_learnable

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states
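Not part of the commit: the two processors above are driven entirely by their enabled/mode flags and the masks stashed on them. The sketch below shows how a caller might flip those flags across a UNet's processors, mirroring what MoMA_generator.toggle_enable_flag and toggle_extract_inject_flag do in the next file; the unet variable is assumed to be a diffusers UNet whose processors were already replaced (see set_ip_adapter in model_lib/moMA_generator.py).

# Minimal sketch, assuming the processors above are already installed on `unet`.
from model_lib.attention_processor import IPAttnProcessor, IPAttnProcessor_Self

def set_moma_mode(unet, mode):
    for proc in unet.attn_processors.values():
        if isinstance(proc, (IPAttnProcessor, IPAttnProcessor_Self)):
            proc.enabled = True
            proc.mode = mode  # 'extract', 'inject', or 'masked_generation'

# set_moma_mode(unet, 'extract')            # pass 1: cache reference features and masks
# set_moma_mode(unet, 'masked_generation')  # pass 2: inject them while sampling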
model_lib/moMA_generator.py
ADDED
@@ -0,0 +1,285 @@
from typing import List
import torch
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from PIL import Image
from model_lib.attention_processor import IPAttnProcessor, IPAttnProcessor_Self, get_mask_from_cross
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
import tqdm


def get_subject_idx(model,prompt,src_subject,device):
    tokenized_prompt = model.tokenizer(prompt,padding="max_length",max_length=model.tokenizer.model_max_length,truncation=True,return_tensors="pt",).to(device)
    input_ids = tokenized_prompt['input_ids']
    src_subject_idxs = []
    for subject,input_id in zip(src_subject,input_ids):
        src_subject_token_id = [model.tokenizer.encode(i, add_special_tokens=False)[0] for i in subject.split(' ')]
        src_subject_idxs = [i for i, x in enumerate(input_id.tolist()) if x in src_subject_token_id]
    return [src_subject_idxs]


def add_function(model):
    @torch.no_grad()
    def generate_with_adapters(
        model,
        prompt_embeds,
        num_inference_steps,
        generator,
        t_range=list(range(0,950)),
    ):

        latents = model.prepare_latents(prompt_embeds.shape[0]//2,4,512,512,prompt_embeds.dtype,prompt_embeds.device,generator)

        model.scheduler.set_timesteps(num_inference_steps)

        iterator = tqdm.tqdm(model.scheduler.timesteps)
        mask_ig_prev = None
        for i, t in enumerate(iterator):
            if not t in t_range:
                model.moMA_generator.toggle_enable_flag('cross')
            else:
                model.moMA_generator.toggle_enable_flag('all')

            latent_model_input = torch.cat([latents] * 2)
            noise_pred = model.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                return_dict=False,
            )[0]

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)

            latents = model.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            mask_ig_prev = (get_mask_from_cross(model.unet.attn_processors))[latents.shape[0]:]

            model.moMA_generator.set_self_mask('self','ig',mask_ig_prev)
            model.moMA_generator.set_self_mask('cross',mask=mask_ig_prev.clone().detach())

        image = model.vae.decode(latents / model.vae.config.scaling_factor, return_dict=False)[0]
        return image, mask_ig_prev.repeat(1,3,1,1) if (not mask_ig_prev==None) else None
    model.generate_with_adapters = generate_with_adapters


class ImageProjModel(torch.nn.Module):
    """Projection Model"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class MoMA_generator:
    def __init__(self, device,args):
        self.args = args
        self.device = device

        noise_scheduler = DDIMScheduler(num_train_timesteps=1000,beta_start=0.00085,beta_end=0.012,beta_schedule="scaled_linear",clip_sample=False,set_alpha_to_one=False,steps_offset=1,)

        print('Loading VAE: stabilityai--sd-vae-ft-mse...')
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

        print('Loading StableDiffusion: Realistic_Vision...')
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "SG161222/Realistic_Vision_V4.0_noVAE",
            torch_dtype=torch.bfloat16,
            scheduler=noise_scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None,
        ).to(self.device)

        self.unet = self.pipe.unet
        add_function(self.pipe)
        self.pipe.moMA_generator = self

        self.set_ip_adapter()
        self.image_proj_model = self.init_proj()

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=768,
            clip_embeddings_dim=1024,
            clip_extra_context_tokens=4,
        ).to(self.device, dtype=torch.bfloat16)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = IPAttnProcessor_Self(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
            else:
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

    @torch.inference_mode()
    def get_image_embeds_CFG(self, llava_emb):
        clip_image_embeds = llava_emb
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def get_image_crossAttn_feature(
        self,
        llava_emb,
        num_samples=1,
    ):
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_CFG(llava_emb)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        return image_prompt_embeds, uncond_image_prompt_embeds

    # features are from self-attention layers of Unet: feed reference image to Unet with t=0
    def get_image_selfAttn_feature(
        self,
        pil_image,
        prompt,
    ):
        self.toggle_enable_flag('self')
        self.toggle_extract_inject_flag('self', 'extract')
        tokenized_prompt = self.pipe.tokenizer(prompt,padding="max_length",truncation=True,return_tensors="pt",).to(self.device)
        text_embeddings = self.pipe.text_encoder(input_ids=tokenized_prompt.input_ids)[0]

        ref_image = pil_image
        ref_image.to(self.device)

        with torch.no_grad(): latents = self.pipe.vae.encode(ref_image).latent_dist.sample()
        latents = latents * self.pipe.vae.config.scaling_factor

        noise = torch.randn_like(latents)
        timesteps = torch.tensor([0],device=latents.device).long() # fixed to 0
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timesteps)

        _ = self.unet(noisy_latents,timestep=timesteps,encoder_hidden_states=text_embeddings)["sample"]
        # features are stored in attn_processors

        return None

    @torch.no_grad()
    def generate_with_MoMA(
        self,
        batch,
        llava_emb=None,
        seed=None,
        device='cuda',
    ):
        self.reset_all()
        img_ig,mask_id,subject,prompt = batch['image'].half().to(device),batch['mask'].half().to(device),batch['label'][0],batch['text'][0]

        prompt = [f"photo of a {subject}. "+ prompt]
        subject_idx = get_subject_idx(self.pipe,prompt,[subject],self.device)
        negative_prompt = None

        # get context-cross-attention feature (from MLLM decoder)
        cond_llava_embeds, uncond_llava_embeds = self.get_image_crossAttn_feature(llava_emb,num_samples=1)
        # get subject-cross-attention feature (from Unet)
        self.get_image_selfAttn_feature(img_ig,subject) # features are stored in attn_processors

        with torch.inference_mode():
            prompt_embeds = self.pipe._encode_prompt(
                prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
            prompt_embeds = torch.cat([prompt_embeds_, cond_llava_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_llava_embeds], dim=1)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None

        self.set_self_mask('eraseAll')
        self.toggle_enable_flag('all')
        self.toggle_extract_inject_flag('all','masked_generation')
        self.set_self_mask('self','id',mask_id)
        self.set_cross_subject_idxs(subject_idx)

        images, mask = self.pipe.generate_with_adapters(
            self.pipe,
            prompt_embeds,
            50,
            generator,
        )
        images = torch.clip((images+1)/2.0,min=0.0,max=1.0)

        return images.cpu(), mask.cpu()

    def set_selfAttn_strength(self, strength):
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = 1.0
            if isinstance(attn_processor, IPAttnProcessor_Self):
                attn_processor.scale = strength

    def set_cross_subject_idxs(self, subject_idxs):
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.subject_idxs = subject_idxs

    def set_self_mask(self,mode,id_ig='', mask=None): # only has effect on self attn of the generation process
        for attn_processor in self.unet.attn_processors.values():
            if mode == 'eraseAll':
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    attn_processor.mask_id,attn_processor.mask_ig = None,None
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.mask_i, attn_processor.mask_ig_prev = None, None
            if mode == 'self':
                if isinstance(attn_processor, IPAttnProcessor_Self):
                    if id_ig == 'id': attn_processor.mask_id = mask
                    if id_ig == 'ig': attn_processor.mask_ig = mask
            if mode == 'cross':
                if isinstance(attn_processor, IPAttnProcessor):
                    attn_processor.mask_ig_prev = mask

    def toggle_enable_flag(self, processor_enable_mode):
        for attn_processor in self.unet.attn_processors.values():
            if processor_enable_mode == 'cross':
                if isinstance(attn_processor, IPAttnProcessor): attn_processor.enabled = True
                if isinstance(attn_processor, IPAttnProcessor_Self): attn_processor.enabled = False
            if processor_enable_mode == 'self':
                if isinstance(attn_processor, IPAttnProcessor): attn_processor.enabled = False
                if isinstance(attn_processor, IPAttnProcessor_Self): attn_processor.enabled = True
            if processor_enable_mode == 'all':
                attn_processor.enabled = True
            if processor_enable_mode == 'none':
                attn_processor.enabled = False

    def toggle_extract_inject_flag(self, processor_name, mode): # mode: str, 'extract' or 'inject' or 'both'(cross only)
        for attn_processor in self.unet.attn_processors.values():
            if processor_name == 'cross':
                if isinstance(attn_processor, IPAttnProcessor): attn_processor.mode = mode
            if processor_name == 'self':
                if isinstance(attn_processor, IPAttnProcessor_Self): attn_processor.mode = mode
            if processor_name == 'all':
                attn_processor.mode = mode

    def reset_all(self,keep_self=False):
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.store_attn, attn_processor.subject_idxs, attn_processor.mask_i, attn_processor.mask_ig_prev, self.subject_idxs = None, None, None, None, None

            if isinstance(attn_processor, IPAttnProcessor_Self):
                attn_processor.mask_id, attn_processor.mask_ig = None, None
                if not keep_self: attn_processor.store_ks, attn_processor.store_vs = [], []
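Not part of the commit: a toy-shape sketch of the classifier-free-guidance layout assembled in generate_with_MoMA above. Dimensions follow SD-1.5 conventions (77 text tokens, 768 channels, 4 projected image tokens); all tensors are random and purely illustrative.

# Toy-shape sketch of the CFG embedding layout and guidance step used above.
import torch

text_cond = torch.randn(1, 77, 768)    # encoded prompt
text_uncond = torch.randn(1, 77, 768)  # encoded negative/empty prompt
img_cond = torch.randn(1, 4, 768)      # ImageProjModel output for the LLaVA embedding
img_uncond = torch.randn(1, 4, 768)    # ImageProjModel output for a zero embedding

cond = torch.cat([text_cond, img_cond], dim=1)        # [1, 81, 768]
uncond = torch.cat([text_uncond, img_uncond], dim=1)  # [1, 81, 768]
prompt_embeds = torch.cat([uncond, cond])             # [2, 81, 768]: the UNet sees both branches

# inside the sampling loop the two branches are recombined with guidance scale 7.5:
noise_uncond, noise_text = torch.randn(2, 4, 64, 64).chunk(2)
noise_pred = noise_uncond + 7.5 * (noise_text - noise_uncond)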
model_lib/modules.py
ADDED
@@ -0,0 +1,151 @@
import os
from PIL import Image
import torch
import torch.nn as nn
from typing import List, Optional
import torch.utils.checkpoint
from torchvision.transforms import ToPILImage
from model_lib.moMA_generator import MoMA_generator
from transformers.activations import ACT2FN
from huggingface_hub import hf_hub_download

from dataset_lib.dataset_eval_MoMA import Dataset_evaluate_MoMA

from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX

def add_function(model):
    def my_llava_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ):
        (_,position_ids,attention_mask,_,inputs_embeds,_) = self.prepare_inputs_labels_for_multimodal(input_ids,position_ids,attention_mask,None,None,images)

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=inputs_embeds,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        return outputs[0]

    model.my_llava_forward = my_llava_forward


class LlamaMLP_mapping(nn.Module):
    def __init__(self, hidden_size, hidden_size_out):
        super().__init__()
        self.hidden_size, self.hidden_size_out = hidden_size, hidden_size_out
        self.gate_proj = nn.Linear(self.hidden_size, self.hidden_size_out, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.hidden_size_out, bias=False)
        self.down_proj = nn.Linear(self.hidden_size_out, self.hidden_size_out, bias=False)
        self.act_fn = ACT2FN["silu"]
        self.act_fn_output = ACT2FN["tanh"]
        self.init_linear()

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj

    def init_linear(self):
        torch.nn.init.xavier_normal_(self.gate_proj.weight)
        self.gate_proj.weight.data = self.gate_proj.weight.data/4.0
        torch.nn.init.xavier_normal_(self.up_proj.weight)
        self.up_proj.weight.data = self.up_proj.weight.data/4.0
        torch.nn.init.xavier_normal_(self.down_proj.weight)
        self.down_proj.weight.data = self.down_proj.weight.data/4.0

class MoMA_main_modal(nn.Module):
    def __init__(self,args):
        super().__init__()
        self.args = args
        self.device = args.device

        self.moMA_generator = MoMA_generator(self.device,args)
        self.unet = self.moMA_generator.pipe.unet
        self.vae = self.moMA_generator.pipe.vae

        print('Loading MoMA: its Multi-modal LLM...')
        model_name = get_model_name_from_path(args.model_path)
        self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(args.model_path, None, model_name, load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit, device=args.device)

        add_function(self.model_llava)

        self.mapping = LlamaMLP_mapping(4096,1024).to(self.device, dtype=torch.bfloat16)
        self.load_saved_components()
        self.freeze_modules()

    def load_saved_components(self):
        if not os.path.exists(self.args.load_attn_adapters):
            print('Loading Attentions and LLM mappings...')
            hf_hub_download(repo_id=self.args.model_path, filename="attn_adapters_projectors.th",local_dir='/'.join(self.args.load_attn_adapters.split('/')[:-1]))

        # load attention adapters and self/cross attentions
        state_dict = torch.load(self.args.load_attn_adapters, map_location="cpu")
        self.moMA_generator.image_proj_model.load_state_dict(state_dict["projectors"])
        attn_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
        attn_layers.load_state_dict(state_dict["self_cross_attentions"],strict=False)

        # load LLM projectors
        self.load_state_dict(state_dict['llm_mapping'],strict=False)

    def freeze_modules(self):
        all_modules = [self.moMA_generator.pipe.vae,self.moMA_generator.pipe.text_encoder,self.unet,self.model_llava,self.mapping]
        for module in all_modules:
            module.train = False
            module.requires_grad_(False)

    def forward_MLLM(self,batch):
        llava_processeds,subjects,prompts = batch['llava_processed'].half().to(self.device),batch['label'],batch['text']

        input_ids,attention_masks,position_ids = [],[],[]
        for subject,prompt in zip(subjects,prompts):
            prompt_construct = f"USER: <image>\n A photo of a {subject}. Describe a new image of the same {subject} in: {prompt}. ASSISTANT: *"
            input_id = tokenizer_image_token(prompt_construct, self.tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
            attention_mask = torch.ones(input_id.shape, dtype=torch.long, device=self.device)
            position_id = torch.tensor(list(range(input_id.shape[-1])), device=self.device)

            position_ids += [position_id]
            attention_masks += [attention_mask[0]]
            input_ids += [input_id[0]]

        input_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in input_ids],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
        position_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in position_ids],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
        attention_masks = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in attention_masks],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])

        output = self.model_llava.my_llava_forward(self.model_llava,input_ids=input_ids,attention_mask=attention_masks,position_ids=position_ids,images=llava_processeds)
        output = self.mapping(output)
        return output[:,-1,:]

    def reset(self):
        self.moMA_generator.reset_all()

    def generate_images(self, rgb_path, mask_path, subject, prompt, strength=1.0, num=1, seed=0):
        batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject, mask_path, self)
        self.moMA_generator.set_selfAttn_strength(strength)

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
            with torch.no_grad():
                ### key steps
                llava_emb = self.forward_MLLM(batch).clone().detach()
                img,mask = self.moMA_generator.generate_with_MoMA(batch,llava_emb=llava_emb,seed=seed,device=self.args.device)
                self.reset()

        result = ToPILImage()(img[0])
        return result
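Not part of the commit: an end-to-end offline sketch that builds the model the same way app.py does and reproduces the two example prompts shipped in output/. It assumes a CUDA GPU and that attn_adapters_projectors.th auto-downloads into checkpoints/; the output filenames are illustrative, and the strength=0.4 setting follows the "change texture" comment in app.py.

# End-to-end sketch (not the committed app flow): build the model, run both example prompts.
import torch
from PIL import Image
from model_lib.modules import MoMA_main_modal
from model_lib.utils import parse_args

args = parse_args()
model = MoMA_main_modal(args).to(args.device, dtype=torch.bfloat16)

rgb = Image.open("example_images/newImages/3.jpg")
mask = Image.open("example_images/newImages/3_mask.jpg")

# context change: strength=1.0 keeps subject detail
model.generate_images(rgb, mask, "car", "A car in autumn with falling leaves.",
                      strength=1.0, seed=2).save("output/car_autumn.jpg")

# texture change: strength around 0.4 trades subject detail for prompt fidelity
model.generate_images(rgb, mask, "car", "A wooden sculpture of a car on the table.",
                      strength=0.4, seed=4).save("output/car_wooden.jpg")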
model_lib/utils.py
ADDED
@@ -0,0 +1,27 @@
import argparse
import torch
from torchvision.transforms import ToPILImage
from PIL import Image

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of MoMA.")
    parser.add_argument("--load_attn_adapters",type=str,default="checkpoints/attn_adapters_projectors.th",help="self_cross attentions and LLM projectors.")
    parser.add_argument("--output_path",type=str,default="output",help="output directory.")
    parser.add_argument("--model_path",type=str,default="KunpengSong/MoMA_llava_7b",help="fine tuned llava (Multi-modal LLM decoder)")
    args = parser.parse_known_args()[0]
    args.device = torch.device("cuda", 0)
    args.load_8bit, args.load_4bit = False, True
    return args

def show_PIL_image(tensor):
    # tensor of shape [3, 3, 512, 512]
    to_pil = ToPILImage()
    images = [to_pil(tensor[i]) for i in range(tensor.shape[0])]

    concatenated_image = Image.new('RGB', (images[0].width * 3, images[0].height))
    x_offset = 0
    for img in images:
        concatenated_image.paste(img, (x_offset, 0))
        x_offset += img.width

    return concatenated_image
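Not part of the commit: a tiny illustration of show_PIL_image, which tiles a batch of three image tensors side by side; the random tensor and output filename are placeholders.

# Illustrative only: tile three random [3, 512, 512] tensors into one preview image.
import torch
from model_lib.utils import show_PIL_image

grid = show_PIL_image(torch.rand(3, 3, 512, 512))  # values in [0, 1] for ToPILImage
grid.save("output/preview_grid.jpg")               # placeholder filename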
output/car_A car in autumn with falling leaves..jpg
ADDED
output/car_A wooden sculpture of a car on the table..jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,32 @@
pip
einops
fastapi
gradio
numpy
requests
sentencepiece
tokenizers>=0.12.1
torch==2.0.1
torchvision==0.15.2
uvicorn
wandb
shortuuid
httpx==0.24.0
deepspeed
peft==0.4.0
transformers==4.36.2
accelerate==0.21.0
bitsandbytes==0.41.0
scikit-learn==1.2.2
sentencepiece==0.1.99
einops==0.6.1
einops-exts==0.0.4
timm==0.6.13
gradio_client
opencv-python
diffusers
torchaudio
torchmetrics
llava-torch
rembg
pytorch_lightning