JackAILab committed on
Commit
9669aec
1 Parent(s): 0cf6544

Upload 292 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +3 -0
  2. app.py +198 -9
  3. attention.py +288 -0
  4. functions.py +599 -0
  5. images/templates/3f8d901770014c1b8f7f261971f0e92.png +3 -0
  6. images/templates/6577b962b6346df03fea83211daaf48.png +0 -0
  7. images/templates/75583964a834abe33b72f52b1a98e84.png +3 -0
  8. images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg +0 -0
  9. models/BiSeNet/6.jpg +0 -0
  10. models/BiSeNet/__init__.py +2 -0
  11. models/BiSeNet/__pycache__/__init__.cpython-38.pyc +0 -0
  12. models/BiSeNet/__pycache__/model.cpython-38.pyc +0 -0
  13. models/BiSeNet/__pycache__/resnet.cpython-38.pyc +0 -0
  14. models/BiSeNet/evaluate.py +95 -0
  15. models/BiSeNet/face_dataset.py +106 -0
  16. models/BiSeNet/hair.png +0 -0
  17. models/BiSeNet/logger.py +23 -0
  18. models/BiSeNet/loss.py +75 -0
  19. models/BiSeNet/makeup.py +130 -0
  20. models/BiSeNet/makeup/116_1.png +0 -0
  21. models/BiSeNet/makeup/116_3.png +0 -0
  22. models/BiSeNet/makeup/116_lip_ori.png +0 -0
  23. models/BiSeNet/makeup/116_ori.png +0 -0
  24. models/BiSeNet/model.py +283 -0
  25. models/BiSeNet/modules/__init__.py +5 -0
  26. models/BiSeNet/modules/bn.py +130 -0
  27. models/BiSeNet/modules/deeplab.py +84 -0
  28. models/BiSeNet/modules/dense.py +42 -0
  29. models/BiSeNet/modules/functions.py +234 -0
  30. models/BiSeNet/modules/misc.py +21 -0
  31. models/BiSeNet/modules/residual.py +88 -0
  32. models/BiSeNet/modules/src/checks.h +15 -0
  33. models/BiSeNet/modules/src/inplace_abn.cpp +95 -0
  34. models/BiSeNet/modules/src/inplace_abn.h +88 -0
  35. models/BiSeNet/modules/src/inplace_abn_cpu.cpp +119 -0
  36. models/BiSeNet/modules/src/inplace_abn_cuda.cu +333 -0
  37. models/BiSeNet/modules/src/inplace_abn_cuda_half.cu +275 -0
  38. models/BiSeNet/modules/src/utils/checks.h +15 -0
  39. models/BiSeNet/modules/src/utils/common.h +49 -0
  40. models/BiSeNet/modules/src/utils/cuda.cuh +71 -0
  41. models/BiSeNet/optimizer.py +69 -0
  42. models/BiSeNet/prepropess_data.py +38 -0
  43. models/BiSeNet/resnet.py +109 -0
  44. models/BiSeNet/test.py +90 -0
  45. models/BiSeNet/train.py +179 -0
  46. models/BiSeNet/transform.py +129 -0
  47. models/BiSeNet_pretrained_for_ConsistentID.pth +3 -0
  48. models/LLaVA/.devcontainer/Dockerfile +53 -0
  49. models/LLaVA/.devcontainer/devcontainer.env +2 -0
  50. models/LLaVA/.devcontainer/devcontainer.json +71 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/templates/3f8d901770014c1b8f7f261971f0e92.png filter=lfs diff=lfs merge=lfs -text
37
+ images/templates/75583964a834abe33b72f52b1a98e84.png filter=lfs diff=lfs merge=lfs -text
38
+ models/LLaVA/images/demo_cli.gif filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,14 +1,203 @@
1
  import gradio as gr
2
- import spaces
3
  import torch
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' 🤔
7
 
8
- @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' 🤗
11
- return f"Hello {zero + n} Tensor"
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
1
  import gradio as gr
 
2
  import torch
3
+ import os
4
+ import glob
5
+ import numpy as np
6
+ from datetime import datetime
7
+ from PIL import Image
8
+ from diffusers.utils import load_image
9
+ from diffusers import EulerDiscreteScheduler
10
+ from pipline_StableDiffusion_ConsistentID import ConsistentIDStableDiffusionPipeline
11
+ import sys
12
+ sys.path.append("./models/LLaVA")
13
+ from llava.model.builder import load_pretrained_model
14
+ from llava.mm_utils import get_model_name_from_path
15
+ from llava.eval.run_llava import eval_model
16
 
17
+ # Load LLaVA for prompt enhancement
18
+ llva_model_path = "/data6/huangjiehui_m22/pretrained_model/llava-v1.5-7b" #TODO
19
+ llva_tokenizer, llva_model, llva_image_processor, llva_context_len = load_pretrained_model(
20
+ model_path=llva_model_path,
21
+ model_base=None,
22
+ model_name=get_model_name_from_path(llva_model_path),)
23
 
24
 
25
+ @torch.inference_mode()
26
+ def Enhance_prompt(prompt,select_images):
27
+
28
+ llva_prompt = f'Please ignore the image. Enhance the following text prompt for me. You can associate more details with the character\'s gesture, environment, and decent clothing:"{prompt}".'
29
+ args = type('Args', (), {
30
+ "model_path": llva_model_path,
31
+ "model_base": None,
32
+ "model_name": get_model_name_from_path(llva_model_path),
33
+ "query": llva_prompt,
34
+ "conv_mode": None,
35
+ "image_file": select_images,
36
+ "sep": ",",
37
+ "temperature": 0,
38
+ "top_p": None,
39
+ "num_beams": 1,
40
+ "max_new_tokens": 512
41
+ })()
42
+ Enhanced_prompt = eval_model(args, llva_tokenizer, llva_model, llva_image_processor)
43
+
44
+ return Enhanced_prompt
45
+
46
+ # print(gr.__version__)
47
+ # 4.16.0
48
+ os.environ['GRADIO_TEMP_DIR'] = "/data6/huangjiehui_m22/z_benke/liaost/ConsistentID/images/gradio_tmp" #TODO
49
+
50
+ script_directory = os.path.dirname(os.path.realpath(__file__))
51
+ device = "cuda"
52
+ # TODO
53
+ base_model_path = "/data6/huangjiehui_m22/pretrained_model/Realistic_Vision_V6.0_B1_noVAE" # TODO
54
+ consistentID_path = "/data6/huangjiehui_m22/z_benke/liaost/ConsistentID/models/ConsistentID_model_facemask_pretrain_50w.bin" # TODO
55
+
56
+ ### Load base model
57
+ pipe = ConsistentIDStableDiffusionPipeline.from_pretrained(
58
+ base_model_path,
59
+ torch_dtype=torch.float16,
60
+ use_safetensors=True,
61
+ variant="fp16"
62
+ ).to(device)
63
+
64
+ ### Load consistentID_model checkpoint
65
+ pipe.load_ConsistentID_model(
66
+ os.path.dirname(consistentID_path),
67
+ subfolder="",
68
+ weight_name=os.path.basename(consistentID_path),
69
+ trigger_word="img",
70
+ )
71
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
72
+
73
+ def process(selected_template_images,costum_image,prompt
74
+ ,negative_prompt,prompt_selected,retouching,model_selected_tab,prompt_selected_tab,width,height,merge_steps):
75
+
76
+ if model_selected_tab==0:
77
+ select_images = load_image(Image.open(selected_template_images))
78
+ else:
79
+ select_images = load_image(Image.fromarray(costum_image))
80
+
81
+ if prompt_selected_tab==0:
82
+ prompt = prompt_selected
83
+ negative_prompt = ""
84
+ need_safetycheck = False
85
+ else:
86
+ need_safetycheck = True
87
+
88
+
89
+ # hyper-parameter
90
+ num_steps = 50
91
+ # merge_steps = 30
92
+
93
+
94
+ if prompt == "":
95
+ prompt = "A man, in a forest"
96
+ prompt = "A man, with backpack, in a raining tropical forest, adventuring, holding a flashlight, in mist, seeking animals"
97
+ prompt = "A person, in a sowm, wearing santa hat and a scarf, with a cottage behind"
98
+ else:
99
+ prompt=Enhance_prompt(prompt,Image.new('RGB', (200, 200), color = 'white'))
100
+ print(prompt)
101
+ pass
102
+
103
+ if negative_prompt == "":
104
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
105
+
106
+ #Extend Prompt
107
+ prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
108
+
109
+ negtive_prompt_group="((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
110
+ negative_prompt = negative_prompt + negtive_prompt_group
111
+
112
+ seed = torch.randint(0, 1000, (1,)).item()
113
+ generator = torch.Generator(device=device).manual_seed(seed)
114
+
115
+ images = pipe(
116
+ prompt=prompt,
117
+ width=width,
118
+ height=height,
119
+ input_id_images=select_images,
120
+ negative_prompt=negative_prompt,
121
+ num_images_per_prompt=1,
122
+ num_inference_steps=num_steps,
123
+ start_merge_step=merge_steps,
124
+ generator=generator,
125
+ retouching=retouching,
126
+ need_safetycheck=need_safetycheck,
127
+ ).images[0]
128
+
129
+ current_date = datetime.today()
130
+ return np.array(images)
131
+
132
+ # Gets the templates
133
+ script_directory = os.path.dirname(os.path.realpath(__file__))
134
+ preset_template = glob.glob("./images/templates/*.png")
135
+ preset_template = preset_template + glob.glob("./images/templates/*.jpg")
136
+
137
+
138
+ with gr.Blocks(title="ConsistentID Demo") as demo:
139
+ gr.Markdown("# ConsistentID Demo")
140
+ gr.Markdown("\
141
+ Put the reference image to be redrawn into the box below (there is a small chance the reference fails; simply submit again).")
142
+ gr.Markdown("\
143
+ If you find our work interesting, please give us a star on GitHub!<br>\
144
+ https://github.com/JackAILab/ConsistentID")
145
+ with gr.Row():
146
+ with gr.Column():
147
+ model_selected_tab = gr.State(0)
148
+ with gr.TabItem("template images") as template_images_tab:
149
+ template_gallery_list = [(i, i) for i in preset_template]
150
+ gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False)
151
+
152
+ def select_function(evt: gr.SelectData):
153
+ return preset_template[evt.index]
154
+
155
+ selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected")
156
+ gallery.select(select_function, None, selected_template_images)
157
+ with gr.TabItem("Upload Image") as upload_image_tab:
158
+ costum_image = gr.Image(label="Upload Image")
159
+
160
+ model_selected_tabs = [template_images_tab, upload_image_tab]
161
+ for i, tab in enumerate(model_selected_tabs):
162
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab])
163
+
164
+ with gr.Column():
165
+ prompt_selected_tab = gr.State(0)
166
+ with gr.TabItem("template prompts") as template_prompts_tab:
167
+ prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[
168
+ "A woman in a wedding dress",
169
+ "A woman, queen, in a gorgeous palace",
170
+ "A man sitting at the beach with sunset",
171
+ "A person, police officer, half body shot",
172
+ "A man, sailor, in a boat above ocean",
173
+ "A women wearing headphone, listening music",
174
+ "A man, firefighter, half body shot"], label=f"prepared prompts")
175
+
176
+ with gr.TabItem("custom prompt") as custom_prompt_tab:
177
+ prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat")
178
+ nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry")
179
+
180
+ prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab]
181
+ for i, tab in enumerate(prompt_selected_tabs):
182
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
183
+
184
+ retouching = gr.Checkbox(label="face retouching",value=False)
185
+ width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
186
+ height = gr.Slider(label="image height",minimum=256,maximum=768,value=768,step=8)
187
+ width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
188
+ height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width])
189
+ merge_steps = gr.Slider(label="step at which to start merging facial details (30 is recommended)",minimum=10,maximum=50,value=30,step=1)
190
+
191
+ btn = gr.Button("Run")
192
+ with gr.Column():
193
+ out = gr.Image(label="Output")
194
+ gr.Markdown('''
195
+ N.B.:<br/>
196
+ - If the face occupies too small a portion of the image, errors become slightly more likely and identity similarity drops noticeably.
197
+ - Prefer prompts that say \"man\" or \"woman\" rather than \"person\"; otherwise the model may be unsure whether the subject is male or female.
198
+ - Due to limited GPU memory on the demo server, there is an upper limit on the output resolution. SDXL generation will be supported as soon as possible.<br/><br/>
199
+ ''')
200
+ btn.click(fn=process, inputs=[selected_template_images,costum_image,prompt,nagetive_prompt,prompt_selected,retouching
201
+ ,model_selected_tab,prompt_selected_tab,width,height,merge_steps], outputs=out)
202
+
203
+ demo.launch()
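
A note on the tab callbacks registered above with tab.select(fn=lambda tabnum=i: tabnum, ...): the loop index is bound through a default argument on purpose. The short, self-contained sketch below (plain Python, independent of this repository) shows why; a plain closure would make every tab report the index of the last one.

    # Late binding: a plain closure reads the loop variable when it is called,
    # while a default argument snapshots its value when the lambda is defined.
    plain = [lambda: i for i in range(3)]
    bound = [lambda tabnum=i: tabnum for i in range(3)]
    print([f() for f in plain])   # [2, 2, 2] -- every callback sees the final i
    print([f() for f in bound])   # [0, 1, 2] -- each callback keeps its own index
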
attention.py ADDED
@@ -0,0 +1,288 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from diffusers.models.lora import LoRALinearLayer
5
+ from functions import AttentionMLP
6
+
7
+
8
+ class FuseModule(nn.Module):
9
+ def __init__(self, embed_dim):
10
+ super().__init__()
11
+ self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
12
+ self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
13
+ self.layer_norm = nn.LayerNorm(embed_dim)
14
+
15
+ def fuse_fn(self, prompt_embeds, id_embeds):
16
+ stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
17
+ stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
18
+ stacked_id_embeds = self.mlp2(stacked_id_embeds)
19
+ stacked_id_embeds = self.layer_norm(stacked_id_embeds)
20
+ return stacked_id_embeds
21
+
22
+ def forward(
23
+ self,
24
+ prompt_embeds,
25
+ id_embeds,
26
+ class_tokens_mask,
27
+ valid_id_mask,
28
+ ) -> torch.Tensor:
29
+ id_embeds = id_embeds.to(prompt_embeds.dtype)
30
+ batch_size, max_num_inputs = id_embeds.shape[:2] # 1,5
31
+ seq_length = prompt_embeds.shape[1] # 77
32
+ flat_id_embeds = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1])
33
+ # flat_id_embeds torch.Size([5, 1, 768])
34
+ valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
35
+ # valid_id_embeds torch.Size([4, 1, 768])
36
+ prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) # torch.Size([77, 768])
37
+ class_tokens_mask = class_tokens_mask.view(-1) # torch.Size([77])
38
+ valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) # torch.Size([4, 768])
39
+ image_token_embeds = prompt_embeds[class_tokens_mask] # torch.Size([4, 768])
40
+ stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) # torch.Size([4, 768])
41
+ assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
42
+ prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
43
+ updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
44
+
45
+ return updated_prompt_embeds
46
+
47
+ class MLP(nn.Module):
48
+ def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
49
+ super().__init__()
50
+ if use_residual:
51
+ assert in_dim == out_dim
52
+ self.layernorm = nn.LayerNorm(in_dim)
53
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
54
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
55
+ self.use_residual = use_residual
56
+ self.act_fn = nn.GELU()
57
+
58
+ def forward(self, x):
59
+
60
+ residual = x
61
+ x = self.layernorm(x)
62
+ x = self.fc1(x)
63
+ x = self.act_fn(x)
64
+ x = self.fc2(x)
65
+ if self.use_residual:
66
+ x = x + residual
67
+ return x
68
+
69
+ class FacialEncoder(nn.Module):
70
+ def __init__(self,image_CLIPModel_encoder=None):
71
+ super().__init__()
72
+ self.visual_projection = AttentionMLP()
73
+ self.fuse_module = FuseModule(768)
74
+
75
+ def forward(self, prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask):
76
+
77
+ bs, num_inputs, token_length, image_dim = multi_image_embeds.shape
78
+ multi_image_embeds_view = multi_image_embeds.view(bs * num_inputs, token_length, image_dim)
79
+ id_embeds = self.visual_projection(multi_image_embeds_view) # torch.Size([5, 1, 768])
80
+ id_embeds = id_embeds.view(bs, num_inputs, 1, -1)
81
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask, valid_id_mask)
82
+
83
+ return updated_prompt_embeds
84
+
85
+ class Consistent_AttProcessor(nn.Module):
86
+
87
+ def __init__(
88
+ self,
89
+ hidden_size=None,
90
+ cross_attention_dim=None,
91
+ rank=4,
92
+ network_alpha=None,
93
+ lora_scale=1.0,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.rank = rank
98
+ self.lora_scale = lora_scale
99
+
100
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
101
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
102
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
103
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
104
+
105
+ def __call__(
106
+ self,
107
+ attn,
108
+ hidden_states,
109
+ encoder_hidden_states=None,
110
+ attention_mask=None,
111
+ temb=None,
112
+ ):
113
+ residual = hidden_states
114
+
115
+ if attn.spatial_norm is not None:
116
+ hidden_states = attn.spatial_norm(hidden_states, temb)
117
+
118
+ input_ndim = hidden_states.ndim
119
+
120
+ if input_ndim == 4:
121
+ batch_size, channel, height, width = hidden_states.shape
122
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
123
+
124
+ batch_size, sequence_length, _ = (
125
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
126
+ )
127
+
128
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
129
+
130
+ if attn.group_norm is not None:
131
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
134
+
135
+ if encoder_hidden_states is None:
136
+ encoder_hidden_states = hidden_states
137
+ elif attn.norm_cross:
138
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
139
+
140
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
141
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
142
+
143
+ query = attn.head_to_batch_dim(query)
144
+ key = attn.head_to_batch_dim(key)
145
+ value = attn.head_to_batch_dim(value)
146
+
147
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
148
+ hidden_states = torch.bmm(attention_probs, value)
149
+ hidden_states = attn.batch_to_head_dim(hidden_states)
150
+
151
+ # linear proj
152
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
153
+ # dropout
154
+ hidden_states = attn.to_out[1](hidden_states)
155
+
156
+ if input_ndim == 4:
157
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
158
+
159
+ if attn.residual_connection:
160
+ hidden_states = hidden_states + residual
161
+
162
+ hidden_states = hidden_states / attn.rescale_output_factor
163
+
164
+ return hidden_states
165
+
166
+
167
+ class Consistent_IPAttProcessor(nn.Module):
168
+
169
+ def __init__(
170
+ self,
171
+ hidden_size,
172
+ cross_attention_dim=None,
173
+ rank=4,
174
+ network_alpha=None,
175
+ lora_scale=1.0,
176
+ scale=1.0,
177
+ num_tokens=4):
178
+ super().__init__()
179
+
180
+ self.rank = rank
181
+ self.lora_scale = lora_scale
182
+ self.num_tokens = num_tokens
183
+
184
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
185
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
186
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
187
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
188
+
189
+
190
+ self.hidden_size = hidden_size
191
+ self.cross_attention_dim = cross_attention_dim
192
+ self.scale = scale
193
+
194
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
195
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
196
+
197
+ for module in [self.to_q_lora, self.to_k_lora, self.to_v_lora, self.to_out_lora, self.to_k_ip, self.to_v_ip]:
198
+ for param in module.parameters():
199
+ param.requires_grad = False
200
+
201
+ def __call__(
202
+ self,
203
+ attn,
204
+ hidden_states,
205
+ encoder_hidden_states=None,
206
+ attention_mask=None,
207
+ scale=1.0,
208
+ temb=None,
209
+ ):
210
+ residual = hidden_states
211
+
212
+ if attn.spatial_norm is not None:
213
+ hidden_states = attn.spatial_norm(hidden_states, temb)
214
+
215
+ input_ndim = hidden_states.ndim
216
+
217
+ if input_ndim == 4:
218
+ batch_size, channel, height, width = hidden_states.shape
219
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
220
+
221
+ batch_size, sequence_length, _ = (
222
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
223
+ )
224
+
225
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
226
+
227
+ if attn.group_norm is not None:
228
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
229
+
230
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
231
+
232
+ if encoder_hidden_states is None:
233
+ encoder_hidden_states = hidden_states
234
+ else:
235
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
236
+ encoder_hidden_states, ip_hidden_states = (
237
+ encoder_hidden_states[:, :end_pos, :],
238
+ encoder_hidden_states[:, end_pos:, :],
239
+ )
240
+ if attn.norm_cross:
241
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
242
+
243
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
244
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
245
+
246
+ inner_dim = key.shape[-1]
247
+ head_dim = inner_dim // attn.heads
248
+
249
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
250
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
+
253
+ hidden_states = F.scaled_dot_product_attention(
254
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
255
+ )
256
+
257
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
258
+ hidden_states = hidden_states.to(query.dtype)
259
+
260
+ ip_key = self.to_k_ip(ip_hidden_states)
261
+ ip_value = self.to_v_ip(ip_hidden_states)
262
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
263
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
264
+
265
+
266
+ ip_hidden_states = F.scaled_dot_product_attention(
267
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
268
+ )
269
+
270
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
271
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
272
+
273
+ hidden_states = hidden_states + self.scale * ip_hidden_states
274
+
275
+ # linear proj
276
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
277
+ # dropout
278
+ hidden_states = attn.to_out[1](hidden_states)
279
+
280
+ if input_ndim == 4:
281
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
282
+
283
+ if attn.residual_connection:
284
+ hidden_states = hidden_states + residual
285
+
286
+ hidden_states = hidden_states / attn.rescale_output_factor
287
+
288
+ return hidden_states
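
This file defines the attention processors, but the commit does not show how they are attached to the UNet; that presumably happens inside pipline_StableDiffusion_ConsistentID.load_ConsistentID_model. The sketch below is only an assumption about that wiring, following the standard diffusers IP-Adapter pattern; the checkpoint path is a placeholder and not part of this repository.

    from diffusers import UNet2DConditionModel
    from attention import Consistent_AttProcessor, Consistent_IPAttProcessor

    # Placeholder path: any Stable Diffusion 1.5-style checkpoint directory.
    unet = UNet2DConditionModel.from_pretrained("path/to/sd15_base", subfolder="unet")

    attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention layers ("attn1") have no cross-attention dimension.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        if cross_attention_dim is None:
            # Plain LoRA-augmented self-attention.
            attn_procs[name] = Consistent_AttProcessor(hidden_size=hidden_size, cross_attention_dim=None)
        else:
            # Cross-attention that also attends to the appended identity tokens.
            attn_procs[name] = Consistent_IPAttProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

    unet.set_attn_processor(attn_procs)
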
functions.py ADDED
@@ -0,0 +1,599 @@
1
+ import numpy as np
2
+ import math
3
+ import types
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ import cv2
8
+ import re
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+ from einops.layers.torch import Rearrange
12
+ from PIL import Image
13
+
14
+ def extract_first_sentence(text):
15
+ end_index = text.find('.')
16
+ if end_index != -1:
17
+ first_sentence = text[:end_index + 1]
18
+ return first_sentence.strip()
19
+ else:
20
+ return text.strip()
21
+
22
+ import re
23
+ def remove_duplicate_keywords(text, keywords):
24
+ keyword_counts = {}
25
+
26
+ words = re.findall(r'\b\w+\b|[.,;!?]', text)
27
+
28
+ for keyword in keywords:
29
+ keyword_counts[keyword] = 0
30
+ for i, word in enumerate(words):
31
+ if word.lower() == keyword.lower():
32
+ keyword_counts[keyword] += 1
33
+ if keyword_counts[keyword] > 1:
34
+ words[i] = ""
35
+ processed_text = " ".join(words)
36
+
37
+ return processed_text
38
+
39
+ def process_text_with_markers(text, parsing_mask_list):
40
+ keywords = ["face", "ears", "eyes", "nose", "mouth"]
41
+ text = remove_duplicate_keywords(text, keywords)
42
+ key_parsing_mask_markers = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
43
+ mapping = {
44
+ "Face": "face",
45
+ "Left_Ear": "ears",
46
+ "Right_Ear": "ears",
47
+ "Left_Eye": "eyes",
48
+ "Right_Eye": "eyes",
49
+ "Nose": "nose",
50
+ "Upper_Lip": "mouth",
51
+ "Lower_Lip": "mouth",
52
+ }
53
+ facial_features_align = []
54
+ markers_align = []
55
+ for key in key_parsing_mask_markers:
56
+ if key in parsing_mask_list:
57
+ mapped_key = mapping.get(key, key.lower())
58
+ if mapped_key not in facial_features_align:
59
+ facial_features_align.append(mapped_key)
60
+ markers_align.append("<|"+mapped_key+"|>")
61
+
62
+ text_marked = text
63
+ align_parsing_mask_list = parsing_mask_list
64
+ for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]):
65
+ pattern = rf'\b{feature}\b'
66
+ text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1)
67
+ if text_marked == text_marked_new:
68
+ for key, value in mapping.items():
69
+ if value == feature:
70
+ if key in align_parsing_mask_list:
71
+ del align_parsing_mask_list[key]
72
+
73
+ text_marked = text_marked_new
74
+
75
+ text_marked = text_marked.replace('\n', '')
76
+
77
+ ordered_text = []
78
+ text_none_makers = []
79
+ facial_marked_count = 0
80
+ skip_count = 0
81
+ for marker in markers_align:
82
+ start_idx = text_marked.find(marker)
83
+ end_idx = start_idx + len(marker)
84
+
85
+ while start_idx > 0 and text_marked[start_idx - 1] not in [",", ".", ";"]:
86
+ start_idx -= 1
87
+
88
+ while end_idx < len(text_marked) and text_marked[end_idx] not in [",", ".", ";"]:
89
+ end_idx += 1
90
+
91
+ context = text_marked[start_idx:end_idx].strip()
92
+ if context == "":
93
+ text_none_makers.append(text_marked[:end_idx])
94
+ else:
95
+ if skip_count!=0:
96
+ skip_count -= 1
97
+ continue
98
+ else:
99
+ ordered_text.append(context + ",")
100
+ text_delete_makers = text_marked[:start_idx] + text_marked[end_idx:]
101
+ text_marked = text_delete_makers
102
+ facial_marked_count += 1
103
+
104
+ align_marked_text = " ".join(ordered_text)
105
+ replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"]
106
+ for item in replace_list:
107
+ align_marked_text = align_marked_text.replace(item, "<|facial|>")
108
+
109
+ return align_marked_text, align_parsing_mask_list
110
+
111
+ def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer):
112
+ input_ids = tokenizer.encode(text)
113
+ image_noun_phrase_end_mask = [False for _ in input_ids]
114
+ facial_noun_phrase_end_mask = [False for _ in input_ids]
115
+ clean_input_ids = []
116
+ clean_index = 0
117
+ image_num = 0
118
+
119
+ for i, id in enumerate(input_ids):
120
+ if id == image_token_id:
121
+ image_noun_phrase_end_mask[clean_index + image_num - 1] = True
122
+ image_num += 1
123
+ elif id == facial_token_id:
124
+ facial_noun_phrase_end_mask[clean_index - 1] = True
125
+ else:
126
+ clean_input_ids.append(id)
127
+ clean_index += 1
128
+
129
+ max_len = tokenizer.model_max_length
130
+
131
+ if len(clean_input_ids) > max_len:
132
+ clean_input_ids = clean_input_ids[:max_len]
133
+ else:
134
+ clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
135
+ max_len - len(clean_input_ids)
136
+ )
137
+
138
+ if len(image_noun_phrase_end_mask) > max_len:
139
+ image_noun_phrase_end_mask = image_noun_phrase_end_mask[:max_len]
140
+ else:
141
+ image_noun_phrase_end_mask = image_noun_phrase_end_mask + [False] * (
142
+ max_len - len(image_noun_phrase_end_mask)
143
+ )
144
+
145
+ if len(facial_noun_phrase_end_mask) > max_len:
146
+ facial_noun_phrase_end_mask = facial_noun_phrase_end_mask[:max_len]
147
+ else:
148
+ facial_noun_phrase_end_mask = facial_noun_phrase_end_mask + [False] * (
149
+ max_len - len(facial_noun_phrase_end_mask)
150
+ )
151
+ clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long)
152
+ image_noun_phrase_end_mask = torch.tensor(image_noun_phrase_end_mask, dtype=torch.bool)
153
+ facial_noun_phrase_end_mask = torch.tensor(facial_noun_phrase_end_mask, dtype=torch.bool)
154
+
155
+ return clean_input_ids.unsqueeze(0), image_noun_phrase_end_mask.unsqueeze(0), facial_noun_phrase_end_mask.unsqueeze(0)
156
+
157
+ def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5):
158
+ image_token_idx = torch.nonzero(image_token_mask, as_tuple=True)[1]
159
+ image_token_idx_mask = torch.ones_like(image_token_idx, dtype=torch.bool)
160
+ if len(image_token_idx) < max_num_objects:
161
+ image_token_idx = torch.cat(
162
+ [
163
+ image_token_idx,
164
+ torch.zeros(max_num_objects - len(image_token_idx), dtype=torch.long),
165
+ ]
166
+ )
167
+ image_token_idx_mask = torch.cat(
168
+ [
169
+ image_token_idx_mask,
170
+ torch.zeros(
171
+ max_num_objects - len(image_token_idx_mask),
172
+ dtype=torch.bool,
173
+ ),
174
+ ]
175
+ )
176
+ facial_token_idx = torch.nonzero(facial_token_mask, as_tuple=True)[1]
177
+ facial_token_idx_mask = torch.ones_like(facial_token_idx, dtype=torch.bool)
178
+ if len(facial_token_idx) < max_num_facials:
179
+ facial_token_idx = torch.cat(
180
+ [
181
+ facial_token_idx,
182
+ torch.zeros(max_num_facials - len(facial_token_idx), dtype=torch.long),
183
+ ]
184
+ )
185
+ facial_token_idx_mask = torch.cat(
186
+ [
187
+ facial_token_idx_mask,
188
+ torch.zeros(
189
+ max_num_facials - len(facial_token_idx_mask),
190
+ dtype=torch.bool,
191
+ ),
192
+ ]
193
+ )
194
+ image_token_idx = image_token_idx.unsqueeze(0)
195
+ image_token_idx_mask = image_token_idx_mask.unsqueeze(0)
196
+
197
+ facial_token_idx = facial_token_idx.unsqueeze(0)
198
+ facial_token_idx_mask = facial_token_idx_mask.unsqueeze(0)
199
+
200
+ return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask
201
+
202
+ def get_object_localization_loss_for_one_layer(
203
+ cross_attention_scores,
204
+ object_segmaps,
205
+ object_token_idx,
206
+ object_token_idx_mask,
207
+ loss_fn,
208
+ ):
209
+ bxh, num_noise_latents, num_text_tokens = cross_attention_scores.shape
210
+ b, max_num_objects, _, _ = object_segmaps.shape
211
+ size = int(num_noise_latents**0.5)
212
+
213
+ object_segmaps = F.interpolate(object_segmaps, size=(size, size), mode="bilinear", antialias=True)
214
+
215
+ object_segmaps = object_segmaps.view(
216
+ b, max_num_objects, -1
217
+ )
218
+
219
+ num_heads = bxh // b
220
+ cross_attention_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens)
221
+
222
+
223
+ object_token_attn_prob = torch.gather(
224
+ cross_attention_scores,
225
+ dim=3,
226
+ index=object_token_idx.view(b, 1, 1, max_num_objects).expand(
227
+ b, num_heads, num_noise_latents, max_num_objects
228
+ ),
229
+ )
230
+ object_segmaps = (
231
+ object_segmaps.permute(0, 2, 1)
232
+ .unsqueeze(1)
233
+ .expand(b, num_heads, num_noise_latents, max_num_objects)
234
+ )
235
+ loss = loss_fn(object_token_attn_prob, object_segmaps)
236
+
237
+ loss = loss * object_token_idx_mask.view(b, 1, max_num_objects)
238
+ object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5
239
+ loss = (loss.sum(dim=2) / object_token_cnt).mean()
240
+
241
+ return loss
242
+
243
+
244
+ def get_object_localization_loss(
245
+ cross_attention_scores,
246
+ object_segmaps,
247
+ image_token_idx,
248
+ image_token_idx_mask,
249
+ loss_fn,
250
+ ):
251
+ num_layers = len(cross_attention_scores)
252
+ loss = 0
253
+ for k, v in cross_attention_scores.items():
254
+ layer_loss = get_object_localization_loss_for_one_layer(
255
+ v, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn
256
+ )
257
+ loss += layer_loss
258
+ return loss / num_layers
259
+
260
+ def unet_store_cross_attention_scores(unet, attention_scores, layers=5):
261
+ from diffusers.models.attention_processor import Attention
262
+
263
+ UNET_LAYER_NAMES = [
264
+ "down_blocks.0",
265
+ "down_blocks.1",
266
+ "down_blocks.2",
267
+ "mid_block",
268
+ "up_blocks.1",
269
+ "up_blocks.2",
270
+ "up_blocks.3",
271
+ ]
272
+
273
+ start_layer = (len(UNET_LAYER_NAMES) - layers) // 2
274
+ end_layer = start_layer + layers
275
+ applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]
276
+
277
+ def make_new_get_attention_scores_fn(name):
278
+ def new_get_attention_scores(module, query, key, attention_mask=None):
279
+ attention_probs = module.old_get_attention_scores(
280
+ query, key, attention_mask
281
+ )
282
+ attention_scores[name] = attention_probs
283
+ return attention_probs
284
+
285
+ return new_get_attention_scores
286
+
287
+ for name, module in unet.named_modules():
288
+ if isinstance(module, Attention) and "attn1" in name:
289
+ if not any(layer in name for layer in applicable_layers):
290
+ continue
291
+
292
+ module.old_get_attention_scores = module.get_attention_scores
293
+ module.get_attention_scores = types.MethodType(
294
+ make_new_get_attention_scores_fn(name), module
295
+ )
296
+ return unet
297
+
298
+ class BalancedL1Loss(nn.Module):
299
+ def __init__(self, threshold=1.0, normalize=False):
300
+ super().__init__()
301
+ self.threshold = threshold
302
+ self.normalize = normalize
303
+
304
+ def forward(self, object_token_attn_prob, object_segmaps):
305
+ if self.normalize:
306
+ object_token_attn_prob = object_token_attn_prob / (
307
+ object_token_attn_prob.max(dim=2, keepdim=True)[0] + 1e-5
308
+ )
309
+ background_segmaps = 1 - object_segmaps
310
+ background_segmaps_sum = background_segmaps.sum(dim=2) + 1e-5
311
+ object_segmaps_sum = object_segmaps.sum(dim=2) + 1e-5
312
+
313
+ background_loss = (object_token_attn_prob * background_segmaps).sum(
314
+ dim=2
315
+ ) / background_segmaps_sum
316
+
317
+ object_loss = (object_token_attn_prob * object_segmaps).sum(
318
+ dim=2
319
+ ) / object_segmaps_sum
320
+
321
+ return background_loss - object_loss
322
+
323
+ def fetch_mask_raw_image(raw_image, mask_image):
324
+
325
+ mask_image = mask_image.resize(raw_image.size)
326
+ mask_raw_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (0, 0, 0)), mask_image)
327
+
328
+ return mask_raw_image
329
+
330
+ mapping_table = [
331
+ {"Mask Value": 0, "Body Part": "Background", "RGB Color": [0, 0, 0]},
332
+ {"Mask Value": 1, "Body Part": "Face", "RGB Color": [255, 0, 0]},
333
+ {"Mask Value": 2, "Body Part": "Left_Eyebrow", "RGB Color": [255, 85, 0]},
334
+ {"Mask Value": 3, "Body Part": "Right_Eyebrow", "RGB Color": [255, 170, 0]},
335
+ {"Mask Value": 4, "Body Part": "Left_Eye", "RGB Color": [255, 0, 85]},
336
+ {"Mask Value": 5, "Body Part": "Right_Eye", "RGB Color": [255, 0, 170]},
337
+ {"Mask Value": 6, "Body Part": "Hair", "RGB Color": [0, 0, 255]},
338
+ {"Mask Value": 7, "Body Part": "Left_Ear", "RGB Color": [85, 0, 255]},
339
+ {"Mask Value": 8, "Body Part": "Right_Ear", "RGB Color": [170, 0, 255]},
340
+ {"Mask Value": 9, "Body Part": "Mouth_External Contour", "RGB Color": [0, 255, 85]},
341
+ {"Mask Value": 10, "Body Part": "Nose", "RGB Color": [0, 255, 0]},
342
+ {"Mask Value": 11, "Body Part": "Mouth_Inner_Contour", "RGB Color": [0, 255, 170]},
343
+ {"Mask Value": 12, "Body Part": "Upper_Lip", "RGB Color": [85, 255, 0]},
344
+ {"Mask Value": 13, "Body Part": "Lower_Lip", "RGB Color": [170, 255, 0]},
345
+ {"Mask Value": 14, "Body Part": "Neck", "RGB Color": [0, 85, 255]},
346
+ {"Mask Value": 15, "Body Part": "Neck_Inner Contour", "RGB Color": [0, 170, 255]},
347
+ {"Mask Value": 16, "Body Part": "Cloth", "RGB Color": [255, 255, 0]},
348
+ {"Mask Value": 17, "Body Part": "Hat", "RGB Color": [255, 0, 255]},
349
+ {"Mask Value": 18, "Body Part": "Earring", "RGB Color": [255, 85, 255]},
350
+ {"Mask Value": 19, "Body Part": "Necklace", "RGB Color": [255, 255, 85]},
351
+ {"Mask Value": 20, "Body Part": "Glasses", "RGB Color": [255, 170, 255]},
352
+ {"Mask Value": 21, "Body Part": "Hand", "RGB Color": [255, 0, 255]},
353
+ {"Mask Value": 22, "Body Part": "Wristband", "RGB Color": [0, 255, 255]},
354
+ {"Mask Value": 23, "Body Part": "Clothes_Upper", "RGB Color": [85, 255, 255]},
355
+ {"Mask Value": 24, "Body Part": "Clothes_Lower", "RGB Color": [170, 255, 255]}
356
+ ]
357
+
358
+
359
+ def masks_for_unique_values(image_raw_mask):
360
+
361
+ image_array = np.array(image_raw_mask)
362
+ unique_values, counts = np.unique(image_array, return_counts=True)
363
+ masks_dict = {}
364
+ for value in unique_values:
365
+ binary_image = np.uint8(image_array == value) * 255
366
+ contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
367
+
368
+ mask = np.zeros_like(image_array)
369
+ for contour in contours:
370
+ cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED)
371
+
372
+ if value == 0:
373
+ body_part="WithoutBackground"
374
+ mask2 = np.where(mask == 255, 0, 255).astype(mask.dtype)
375
+ masks_dict[body_part] = Image.fromarray(mask2)
376
+
377
+ body_part = next((entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value), f"Unknown_{value}")
378
+ if body_part.startswith("Unknown_"):
379
+ continue
380
+
381
+ masks_dict[body_part] = Image.fromarray(mask)
382
+
383
+ return masks_dict
384
+ # FFN
385
+ def FeedForward(dim, mult=4):
386
+ inner_dim = int(dim * mult)
387
+ return nn.Sequential(
388
+ nn.LayerNorm(dim),
389
+ nn.Linear(dim, inner_dim, bias=False),
390
+ nn.GELU(),
391
+ nn.Linear(inner_dim, dim, bias=False),
392
+ )
393
+
394
+
395
+ def reshape_tensor(x, heads):
396
+ bs, length, width = x.shape
397
+ x = x.view(bs, length, heads, -1)
398
+ x = x.transpose(1, 2)
399
+ x = x.reshape(bs, heads, length, -1)
400
+ return x
401
+
402
+ class PerceiverAttention(nn.Module):
403
+ def __init__(self, *, dim, dim_head=64, heads=8):
404
+ super().__init__()
405
+ self.scale = dim_head**-0.5
406
+ self.dim_head = dim_head
407
+ self.heads = heads
408
+ inner_dim = dim_head * heads
409
+
410
+ self.norm1 = nn.LayerNorm(dim)
411
+ self.norm2 = nn.LayerNorm(dim)
412
+
413
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
414
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
415
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
416
+
417
+ def forward(self, x, latents):
418
+ """
419
+ Args:
420
+ x (torch.Tensor): image features
421
+ shape (b, n1, D)
422
+ latent (torch.Tensor): latent features
423
+ shape (b, n2, D)
424
+ """
425
+
426
+ x = self.norm1(x)
427
+ latents = self.norm2(latents)
428
+
429
+ b, l, _ = latents.shape
430
+
431
+ q = self.to_q(latents)
432
+ kv_input = torch.cat((x, latents), dim=-2)
433
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
434
+
435
+ q = reshape_tensor(q, self.heads)
436
+ k = reshape_tensor(k, self.heads)
437
+ v = reshape_tensor(v, self.heads)
438
+
439
+ # attention
440
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
441
+ weight = (q * scale) @ (k * scale).transpose(-2, -1)
442
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
443
+ out = weight @ v
444
+
445
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
446
+
447
+ return self.to_out(out)
448
+
449
+ class FacePerceiverResampler(torch.nn.Module):
450
+ def __init__(
451
+ self,
452
+ *,
453
+ dim=768,
454
+ depth=4,
455
+ dim_head=64,
456
+ heads=16,
457
+ embedding_dim=1280,
458
+ output_dim=768,
459
+ ff_mult=4,
460
+ ):
461
+ super().__init__()
462
+
463
+ self.proj_in = torch.nn.Linear(embedding_dim, dim)
464
+ self.proj_out = torch.nn.Linear(dim, output_dim)
465
+ self.norm_out = torch.nn.LayerNorm(output_dim)
466
+ self.layers = torch.nn.ModuleList([])
467
+ for _ in range(depth):
468
+ self.layers.append(
469
+ torch.nn.ModuleList(
470
+ [
471
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
472
+ FeedForward(dim=dim, mult=ff_mult),
473
+ ]
474
+ )
475
+ )
476
+ def forward(self, latents, x): # latents.torch.Size([2, 4, 768]) x.torch.Size([2, 257, 1280])
477
+ x = self.proj_in(x) # x.torch.Size([2, 257, 768])
478
+ for attn, ff in self.layers:
479
+ latents = attn(x, latents) + latents # latents.torch.Size([2, 4, 768])
480
+ latents = ff(latents) + latents # latents.torch.Size([2, 4, 768])
481
+ latents = self.proj_out(latents)
482
+ return self.norm_out(latents)
483
+
484
+ class ProjPlusModel(torch.nn.Module):
485
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
486
+ super().__init__()
487
+
488
+ self.cross_attention_dim = cross_attention_dim
489
+ self.num_tokens = num_tokens
490
+
491
+ self.proj = torch.nn.Sequential(
492
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
493
+ torch.nn.GELU(),
494
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
495
+ )
496
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
497
+
498
+ self.perceiver_resampler = FacePerceiverResampler(
499
+ dim=cross_attention_dim,
500
+ depth=4,
501
+ dim_head=64,
502
+ heads=cross_attention_dim // 64,
503
+ embedding_dim=clip_embeddings_dim,
504
+ output_dim=cross_attention_dim,
505
+ ff_mult=4,
506
+ )
507
+
508
+ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
509
+
510
+ x = self.proj(id_embeds)
511
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
512
+ x = self.norm(x)
513
+ out = self.perceiver_resampler(x, clip_embeds)
514
+ if shortcut:
515
+ out = scale * x + out
516
+ return out
517
+
518
+ class AttentionMLP(nn.Module):
519
+ def __init__(
520
+ self,
521
+ dtype=torch.float16,
522
+ dim=1024,
523
+ depth=8,
524
+ dim_head=64,
525
+ heads=16,
526
+ single_num_tokens=1,
527
+ embedding_dim=1280,
528
+ output_dim=768,
529
+ ff_mult=4,
530
+ max_seq_len: int = 257*2,
531
+ apply_pos_emb: bool = False,
532
+ num_latents_mean_pooled: int = 0,
533
+ ):
534
+ super().__init__()
535
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
536
+
537
+ self.single_num_tokens = single_num_tokens
538
+ self.latents = nn.Parameter(torch.randn(1, self.single_num_tokens, dim) / dim**0.5)
539
+
540
+ self.proj_in = nn.Linear(embedding_dim, dim)
541
+
542
+ self.proj_out = nn.Linear(dim, output_dim)
543
+ self.norm_out = nn.LayerNorm(output_dim)
544
+
545
+ self.to_latents_from_mean_pooled_seq = (
546
+ nn.Sequential(
547
+ nn.LayerNorm(dim),
548
+ nn.Linear(dim, dim * num_latents_mean_pooled),
549
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
550
+ )
551
+ if num_latents_mean_pooled > 0
552
+ else None
553
+ )
554
+
555
+ self.layers = nn.ModuleList([])
556
+ for _ in range(depth):
557
+ self.layers.append(
558
+ nn.ModuleList(
559
+ [
560
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
561
+ FeedForward(dim=dim, mult=ff_mult),
562
+ ]
563
+ )
564
+ )
565
+
566
+ def forward(self, x):
567
+ if self.pos_emb is not None:
568
+ n, device = x.shape[1], x.device
569
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
570
+ x = x + pos_emb
571
+ # x torch.Size([5, 257, 1280])
572
+ latents = self.latents.repeat(x.size(0), 1, 1)
573
+
574
+ x = self.proj_in(x) # torch.Size([5, 257, 1024])
575
+
576
+ if self.to_latents_from_mean_pooled_seq:
577
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
578
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
579
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
580
+
581
+ for attn, ff in self.layers:
582
+ latents = attn(x, latents) + latents
583
+ latents = ff(latents) + latents
584
+
585
+ latents = self.proj_out(latents)
586
+ return self.norm_out(latents)
587
+
588
+
589
+ def masked_mean(t, *, dim, mask=None):
590
+ if mask is None:
591
+ return t.mean(dim=dim)
592
+
593
+ denom = mask.sum(dim=dim, keepdim=True)
594
+ mask = rearrange(mask, "b n -> b n 1")
595
+ masked_t = t.masked_fill(~mask, 0.0)
596
+
597
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
598
+
599
+
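
As a quick orientation for the mask utilities above, the fragment below builds a tiny synthetic parsing map, splits it into per-part masks with masks_for_unique_values, and applies one of them with fetch_mask_raw_image. It is an illustrative sketch only: the toy array stands in for a real BiSeNet parsing result, and the values used (1 for Face, 10 for Nose) follow the mapping_table defined above.

    import numpy as np
    from PIL import Image
    from functions import masks_for_unique_values, fetch_mask_raw_image

    # Toy parsing map standing in for a BiSeNet output: background (0),
    # a square labelled "Face" (1) and a smaller square labelled "Nose" (10).
    parsing = np.zeros((64, 64), dtype=np.uint8)
    parsing[8:56, 8:56] = 1
    parsing[28:36, 28:36] = 10
    parsing_img = Image.fromarray(parsing)

    masks = masks_for_unique_values(parsing_img)
    print(list(masks.keys()))  # includes 'Face', 'Nose' and 'WithoutBackground'

    # Keep only the face region of an image; everything else is blacked out.
    raw = Image.new("RGB", (64, 64), (200, 180, 160))
    face_only = fetch_mask_raw_image(raw, masks["Face"])
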
images/templates/3f8d901770014c1b8f7f261971f0e92.png ADDED

Git LFS Details

  • SHA256: 4fa9319750b9927075934c40a180766e75ff539711293581dae6bac5963b9d05
  • Pointer size: 132 Bytes
  • Size of remote file: 2.06 MB
images/templates/6577b962b6346df03fea83211daaf48.png ADDED
images/templates/75583964a834abe33b72f52b1a98e84.png ADDED

Git LFS Details

  • SHA256: 318c942eb3cc8a1f9320b2ea84a88cd95067785c07f8ae1dd18fe6c4cf8e8282
  • Pointer size: 132 Bytes
  • Size of remote file: 7.54 MB
images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg ADDED
models/BiSeNet/6.jpg ADDED
models/BiSeNet/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ #__init__.py
2
+ # from BiSeNet.model import *
models/BiSeNet/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (198 Bytes).
 
models/BiSeNet/__pycache__/model.cpython-38.pyc ADDED
Binary file (9.18 kB).
 
models/BiSeNet/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (3.62 kB).
 
models/BiSeNet/evaluate.py ADDED
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from logger import setup_logger
5
+ from model import BiSeNet
6
+ from face_dataset import FaceMask
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import DataLoader
11
+ import torch.nn.functional as F
12
+ import torch.distributed as dist
13
+
14
+ import os
15
+ import os.path as osp
16
+ import logging
17
+ import time
18
+ import numpy as np
19
+ from tqdm import tqdm
20
+ import math
21
+ from PIL import Image
22
+ import torchvision.transforms as transforms
23
+ import cv2
24
+
25
+ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
26
+ # Colors for all 20 parts
27
+ part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
28
+ [255, 0, 85], [255, 0, 170],
29
+ [0, 255, 0], [85, 255, 0], [170, 255, 0],
30
+ [0, 255, 85], [0, 255, 170],
31
+ [0, 0, 255], [85, 0, 255], [170, 0, 255],
32
+ [0, 85, 255], [0, 170, 255],
33
+ [255, 255, 0], [255, 255, 85], [255, 255, 170],
34
+ [255, 0, 255], [255, 85, 255], [255, 170, 255],
35
+ [0, 255, 255], [85, 255, 255], [170, 255, 255]]
36
+
37
+ im = np.array(im)
38
+ vis_im = im.copy().astype(np.uint8)
39
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
40
+ vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
41
+ vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
42
+
43
+ num_of_class = np.max(vis_parsing_anno)
44
+
45
+ for pi in range(1, num_of_class + 1):
46
+ index = np.where(vis_parsing_anno == pi)
47
+ vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
48
+
49
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
50
+ # print(vis_parsing_anno_color.shape, vis_im.shape)
51
+ vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
52
+
53
+ # Save result or not
54
+ if save_im:
55
+ cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
56
+
57
+ # return vis_im
58
+
59
+ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
60
+
61
+ if not os.path.exists(respth):
62
+ os.makedirs(respth)
63
+
64
+ n_classes = 19
65
+ net = BiSeNet(n_classes=n_classes)
66
+ net.cuda()
67
+ save_pth = osp.join('res/cp', cp)
68
+ net.load_state_dict(torch.load(save_pth))
69
+ net.eval()
70
+
71
+ to_tensor = transforms.Compose([
72
+ transforms.ToTensor(),
73
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
74
+ ])
75
+ with torch.no_grad():
76
+ for image_path in os.listdir(dspth):
77
+ img = Image.open(osp.join(dspth, image_path))
78
+ image = img.resize((512, 512), Image.BILINEAR)
79
+ img = to_tensor(image)
80
+ img = torch.unsqueeze(img, 0)
81
+ img = img.cuda()
82
+ out = net(img)[0]
83
+ parsing = out.squeeze(0).cpu().numpy().argmax(0)
84
+
85
+ vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+ if __name__ == "__main__":
94
+ setup_logger('./res')
95
+ evaluate()
models/BiSeNet/face_dataset.py ADDED
@@ -0,0 +1,106 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import torchvision.transforms as transforms
7
+
8
+ import os.path as osp
9
+ import os
10
+ from PIL import Image
11
+ import numpy as np
12
+ import json
13
+ import cv2
14
+
15
+ from transform import *
16
+
17
+
18
+
19
+ class FaceMask(Dataset):
20
+ def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs):
21
+ super(FaceMask, self).__init__(*args, **kwargs)
22
+ assert mode in ('train', 'val', 'test')
23
+ self.mode = mode
24
+ self.ignore_lb = 255
25
+ self.rootpth = rootpth
26
+
27
+ self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img'))
28
+
29
+ # pre-processing
30
+ self.to_tensor = transforms.Compose([
31
+ transforms.ToTensor(),
32
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
33
+ ])
34
+ self.trans_train = Compose([
35
+ ColorJitter(
36
+ brightness=0.5,
37
+ contrast=0.5,
38
+ saturation=0.5),
39
+ HorizontalFlip(),
40
+ RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
41
+ RandomCrop(cropsize)
42
+ ])
43
+
44
+ def __getitem__(self, idx):
45
+ impth = self.imgs[idx]
46
+ img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth))
47
+ img = img.resize((512, 512), Image.BILINEAR)
48
+ label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P')
49
+ # print(np.unique(np.array(label)))
50
+ if self.mode == 'train':
51
+ im_lb = dict(im=img, lb=label)
52
+ im_lb = self.trans_train(im_lb)
53
+ img, label = im_lb['im'], im_lb['lb']
54
+ img = self.to_tensor(img)
55
+ label = np.array(label).astype(np.int64)[np.newaxis, :]
56
+ return img, label
57
+
58
+ def __len__(self):
59
+ return len(self.imgs)
60
+
61
+
62
+ if __name__ == "__main__":
63
+ face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
64
+ face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
65
+ mask_path = '/home/zll/data/CelebAMask-HQ/mask'
66
+ counter = 0
67
+ total = 0
68
+ for i in range(15):
69
+ # files = os.listdir(osp.join(face_sep_mask, str(i)))
70
+
71
+ atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
72
+ 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
73
+
74
+ for j in range(i*2000, (i+1)*2000):
75
+
76
+ mask = np.zeros((512, 512))
77
+
78
+ for l, att in enumerate(atts, 1):
79
+ total += 1
80
+ file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
81
+ path = osp.join(face_sep_mask, str(i), file_name)
82
+
83
+ if os.path.exists(path):
84
+ counter += 1
85
+ sep_mask = np.array(Image.open(path).convert('P'))
86
+ # print(np.unique(sep_mask))
87
+
88
+ mask[sep_mask == 225] = l
89
+ cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
90
+ print(j)
91
+
92
+ print(counter, total)
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
models/BiSeNet/hair.png ADDED
models/BiSeNet/logger.py ADDED
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import os.path as osp
6
+ import time
7
+ import sys
8
+ import logging
9
+
10
+ import torch.distributed as dist
11
+
12
+
13
+ def setup_logger(logpth):
14
+ logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
15
+ logfile = osp.join(logpth, logfile)
16
+ FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
17
+ log_level = logging.INFO
18
+ if dist.is_initialized() and not dist.get_rank()==0:
19
+ log_level = logging.ERROR
20
+ logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
21
+ logging.root.addHandler(logging.StreamHandler())
22
+
23
+
models/BiSeNet/loss.py ADDED
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import numpy as np
10
+
11
+
12
+ class OhemCELoss(nn.Module):
13
+ def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
14
+ super(OhemCELoss, self).__init__()
15
+ self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
16
+ self.n_min = n_min
17
+ self.ignore_lb = ignore_lb
18
+ self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
19
+
20
+ def forward(self, logits, labels):
21
+ N, C, H, W = logits.size()
22
+ loss = self.criteria(logits, labels).view(-1)
23
+ loss, _ = torch.sort(loss, descending=True)
24
+ if loss[self.n_min] > self.thresh:
25
+ loss = loss[loss>self.thresh]
26
+ else:
27
+ loss = loss[:self.n_min]
28
+ return torch.mean(loss)
29
+
30
+
31
+ class SoftmaxFocalLoss(nn.Module):
32
+ def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
33
+ super(SoftmaxFocalLoss, self).__init__()
34
+ self.gamma = gamma
35
+ self.nll = nn.NLLLoss(ignore_index=ignore_lb)
36
+
37
+ def forward(self, logits, labels):
38
+ scores = F.softmax(logits, dim=1)
39
+ factor = torch.pow(1.-scores, self.gamma)
40
+ log_score = F.log_softmax(logits, dim=1)
41
+ log_score = factor * log_score
42
+ loss = self.nll(log_score, labels)
43
+ return loss
44
+
45
+
46
+ if __name__ == '__main__':
47
+ torch.manual_seed(15)
48
+ criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
49
+ criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
50
+ net1 = nn.Sequential(
51
+ nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
52
+ )
53
+ net1.cuda()
54
+ net1.train()
55
+ net2 = nn.Sequential(
56
+ nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
57
+ )
58
+ net2.cuda()
59
+ net2.train()
60
+
61
+ with torch.no_grad():
62
+ inten = torch.randn(16, 3, 20, 20).cuda()
63
+ lbs = torch.randint(0, 19, [16, 20, 20]).cuda()
64
+ lbs[1, :, :] = 255
65
+
66
+ logits1 = net1(inten)
67
+ logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear')
68
+ logits2 = net2(inten)
69
+ logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear')
70
+
71
+ loss1 = criteria1(logits1, lbs)
72
+ loss2 = criteria2(logits2, lbs)
73
+ loss = loss1 + loss2
74
+ print(loss.detach().cpu())
75
+ loss.backward()
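
For training, `n_min` is usually tied to the crop size and batch size so that the OHEM selection keeps at least a fixed fraction of pixels; a hedged sketch with illustrative values (the numbers are not taken from this diff).

```python
# Hedged sketch: typical OhemCELoss setup for 19-class parsing on 448x448 crops.
# Values are illustrative; OhemCELoss moves its threshold to CUDA in __init__,
# so a GPU is required here.
from loss import OhemCELoss

n_img, cropsize = 16, 448
n_min = n_img * cropsize * cropsize // 16   # keep at least 1/16 of all pixels in the batch
criteria = OhemCELoss(thresh=0.7, n_min=n_min, ignore_lb=255)
```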
models/BiSeNet/makeup.py ADDED
@@ -0,0 +1,130 @@
1
+ import cv2
2
+ import os
3
+ import numpy as np
4
+ from skimage.filters import gaussian
5
+
6
+
7
+ def sharpen(img):
8
+ img = img * 1.0
9
+ gauss_out = gaussian(img, sigma=5, multichannel=True)
10
+
11
+ alpha = 1.5
12
+ img_out = (img - gauss_out) * alpha + img
13
+
14
+ img_out = img_out / 255.0
15
+
16
+ mask_1 = img_out < 0
17
+ mask_2 = img_out > 1
18
+
19
+ img_out = img_out * (1 - mask_1)
20
+ img_out = img_out * (1 - mask_2) + mask_2
21
+ img_out = np.clip(img_out, 0, 1)
22
+ img_out = img_out * 255
23
+ return np.array(img_out, dtype=np.uint8)
24
+
25
+
26
+ def hair(image, parsing, part=17, color=[230, 50, 20]):
27
+ b, g, r = color #[10, 50, 250] # [10, 250, 10]
28
+ tar_color = np.zeros_like(image)
29
+ tar_color[:, :, 0] = b
30
+ tar_color[:, :, 1] = g
31
+ tar_color[:, :, 2] = r
32
+
33
+ image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
34
+ tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV)
35
+
36
+ if part == 12 or part == 13:
37
+ image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2]
38
+ else:
39
+ image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1]
40
+
41
+ changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
42
+
43
+ if part == 17:
44
+ changed = sharpen(changed)
45
+
46
+ changed[parsing != part] = image[parsing != part]
47
+ # changed = cv2.resize(changed, (512, 512))
48
+ return changed
49
+
50
+ #
51
+ # def lip(image, parsing, part=17, color=[230, 50, 20]):
52
+ # b, g, r = color #[10, 50, 250] # [10, 250, 10]
53
+ # tar_color = np.zeros_like(image)
54
+ # tar_color[:, :, 0] = b
55
+ # tar_color[:, :, 1] = g
56
+ # tar_color[:, :, 2] = r
57
+ #
58
+ # image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
59
+ # il, ia, ib = cv2.split(image_lab)
60
+ #
61
+ # tar_lab = cv2.cvtColor(tar_color, cv2.COLOR_BGR2Lab)
62
+ # tl, ta, tb = cv2.split(tar_lab)
63
+ #
64
+ # image_lab[:, :, 0] = np.clip(il - np.mean(il) + tl, 0, 100)
65
+ # image_lab[:, :, 1] = np.clip(ia - np.mean(ia) + ta, -127, 128)
66
+ # image_lab[:, :, 2] = np.clip(ib - np.mean(ib) + tb, -127, 128)
67
+ #
68
+ #
69
+ # changed = cv2.cvtColor(image_lab, cv2.COLOR_Lab2BGR)
70
+ #
71
+ # if part == 17:
72
+ # changed = sharpen(changed)
73
+ #
74
+ # changed[parsing != part] = image[parsing != part]
75
+ # # changed = cv2.resize(changed, (512, 512))
76
+ # return changed
77
+
78
+
79
+ if __name__ == '__main__':
80
+ # 1 face
81
+ # 10 nose
82
+ # 11 teeth
83
+ # 12 upper lip
84
+ # 13 lower lip
85
+ # 17 hair
86
+ num = 116
87
+ table = {
88
+ 'hair': 17,
89
+ 'upper_lip': 12,
90
+ 'lower_lip': 13
91
+ }
92
+ image_path = '/home/zll/data/CelebAMask-HQ/test-img/{}.jpg'.format(num)
93
+ parsing_path = 'res/test_res/{}.png'.format(num)
94
+
95
+ image = cv2.imread(image_path)
96
+ ori = image.copy()
97
+ parsing = np.array(cv2.imread(parsing_path, 0))
98
+ parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST)
99
+
100
+ parts = [table['hair'], table['upper_lip'], table['lower_lip']]
101
+ # colors = [[20, 20, 200], [100, 100, 230], [100, 100, 230]]
102
+ colors = [[100, 200, 100]]
103
+ for part, color in zip(parts, colors):
104
+ image = hair(image, parsing, part, color)
105
+ cv2.imwrite('res/makeup/116_ori.png', cv2.resize(ori, (512, 512)))
106
+ cv2.imwrite('res/makeup/116_2.png', cv2.resize(image, (512, 512)))
107
+
108
+ cv2.imshow('image', cv2.resize(ori, (512, 512)))
109
+ cv2.imshow('color', cv2.resize(image, (512, 512)))
110
+
111
+ # cv2.imshow('image', ori)
112
+ # cv2.imshow('color', image)
113
+
114
+ cv2.waitKey(0)
115
+ cv2.destroyAllWindows()
models/BiSeNet/makeup/116_1.png ADDED
models/BiSeNet/makeup/116_3.png ADDED
models/BiSeNet/makeup/116_lip_ori.png ADDED
models/BiSeNet/makeup/116_ori.png ADDED
models/BiSeNet/model.py ADDED
@@ -0,0 +1,283 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ self.resnet = Resnet18()
96
+ self.arm16 = AttentionRefinementModule(256, 128)
97
+ self.arm32 = AttentionRefinementModule(512, 128)
98
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
+
102
+ self.init_weight()
103
+
104
+ def forward(self, x):
105
+ H0, W0 = x.size()[2:]
106
+ feat8, feat16, feat32 = self.resnet(x)
107
+ H8, W8 = feat8.size()[2:]
108
+ H16, W16 = feat16.size()[2:]
109
+ H32, W32 = feat32.size()[2:]
110
+
111
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
+ avg = self.conv_avg(avg)
113
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
+
115
+ feat32_arm = self.arm32(feat32)
116
+ feat32_sum = feat32_arm + avg_up
117
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
+ feat32_up = self.conv_head32(feat32_up)
119
+
120
+ feat16_arm = self.arm16(feat16)
121
+ feat16_sum = feat16_arm + feat32_up
122
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
+ feat16_up = self.conv_head16(feat16_up)
124
+
125
+ return feat8, feat16_up, feat32_up # x8, x8, x16
126
+
127
+ def init_weight(self):
128
+ for ly in self.children():
129
+ if isinstance(ly, nn.Conv2d):
130
+ nn.init.kaiming_normal_(ly.weight, a=1)
131
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132
+
133
+ def get_params(self):
134
+ wd_params, nowd_params = [], []
135
+ for name, module in self.named_modules():
136
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
137
+ wd_params.append(module.weight)
138
+ if not module.bias is None:
139
+ nowd_params.append(module.bias)
140
+ elif isinstance(module, nn.BatchNorm2d):
141
+ nowd_params += list(module.parameters())
142
+ return wd_params, nowd_params
143
+
144
+
145
+ ### This is not used, since I replace this with the resnet feature with the same size
146
+ class SpatialPath(nn.Module):
147
+ def __init__(self, *args, **kwargs):
148
+ super(SpatialPath, self).__init__()
149
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153
+ self.init_weight()
154
+
155
+ def forward(self, x):
156
+ feat = self.conv1(x)
157
+ feat = self.conv2(feat)
158
+ feat = self.conv3(feat)
159
+ feat = self.conv_out(feat)
160
+ return feat
161
+
162
+ def init_weight(self):
163
+ for ly in self.children():
164
+ if isinstance(ly, nn.Conv2d):
165
+ nn.init.kaiming_normal_(ly.weight, a=1)
166
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167
+
168
+ def get_params(self):
169
+ wd_params, nowd_params = [], []
170
+ for name, module in self.named_modules():
171
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172
+ wd_params.append(module.weight)
173
+ if not module.bias is None:
174
+ nowd_params.append(module.bias)
175
+ elif isinstance(module, nn.BatchNorm2d):
176
+ nowd_params += list(module.parameters())
177
+ return wd_params, nowd_params
178
+
179
+
180
+ class FeatureFusionModule(nn.Module):
181
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
182
+ super(FeatureFusionModule, self).__init__()
183
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184
+ self.conv1 = nn.Conv2d(out_chan,
185
+ out_chan//4,
186
+ kernel_size = 1,
187
+ stride = 1,
188
+ padding = 0,
189
+ bias = False)
190
+ self.conv2 = nn.Conv2d(out_chan//4,
191
+ out_chan,
192
+ kernel_size = 1,
193
+ stride = 1,
194
+ padding = 0,
195
+ bias = False)
196
+ self.relu = nn.ReLU(inplace=True)
197
+ self.sigmoid = nn.Sigmoid()
198
+ self.init_weight()
199
+
200
+ def forward(self, fsp, fcp):
201
+ fcat = torch.cat([fsp, fcp], dim=1)
202
+ feat = self.convblk(fcat)
203
+ atten = F.avg_pool2d(feat, feat.size()[2:])
204
+ atten = self.conv1(atten)
205
+ atten = self.relu(atten)
206
+ atten = self.conv2(atten)
207
+ atten = self.sigmoid(atten)
208
+ feat_atten = torch.mul(feat, atten)
209
+ feat_out = feat_atten + feat
210
+ return feat_out
211
+
212
+ def init_weight(self):
213
+ for ly in self.children():
214
+ if isinstance(ly, nn.Conv2d):
215
+ nn.init.kaiming_normal_(ly.weight, a=1)
216
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217
+
218
+ def get_params(self):
219
+ wd_params, nowd_params = [], []
220
+ for name, module in self.named_modules():
221
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222
+ wd_params.append(module.weight)
223
+ if not module.bias is None:
224
+ nowd_params.append(module.bias)
225
+ elif isinstance(module, nn.BatchNorm2d):
226
+ nowd_params += list(module.parameters())
227
+ return wd_params, nowd_params
228
+
229
+
230
+ class BiSeNet(nn.Module):
231
+ def __init__(self, n_classes, *args, **kwargs):
232
+ super(BiSeNet, self).__init__()
233
+ self.cp = ContextPath()
234
+ ## here self.sp is deleted
235
+ self.ffm = FeatureFusionModule(256, 256)
236
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
237
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239
+ self.init_weight()
240
+
241
+ def forward(self, x):
242
+ H, W = x.size()[2:]
243
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
246
+
247
+ feat_out = self.conv_out(feat_fuse)
248
+ feat_out16 = self.conv_out16(feat_cp8)
249
+ feat_out32 = self.conv_out32(feat_cp16)
250
+
251
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254
+ return feat_out, feat_out16, feat_out32
255
+
256
+ def init_weight(self):
257
+ for ly in self.children():
258
+ if isinstance(ly, nn.Conv2d):
259
+ nn.init.kaiming_normal_(ly.weight, a=1)
260
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261
+
262
+ def get_params(self):
263
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264
+ for name, child in self.named_children():
265
+ child_wd_params, child_nowd_params = child.get_params()
266
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267
+ lr_mul_wd_params += child_wd_params
268
+ lr_mul_nowd_params += child_nowd_params
269
+ else:
270
+ wd_params += child_wd_params
271
+ nowd_params += child_nowd_params
272
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273
+
274
+
275
+ if __name__ == "__main__":
276
+ net = BiSeNet(19)
277
+ net.cuda()
278
+ net.eval()
279
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
280
+ out, out16, out32 = net(in_ten)
281
+ print(out.shape)
282
+
283
+ net.get_params()
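
A hedged inference sketch for the parser defined above, assuming it is run from the models/BiSeNet directory; the checkpoint path, the 512x512 input size and the ImageNet normalization constants are assumptions, not values taken from this diff.

```python
# Hedged sketch: single-image face parsing with the BiSeNet defined above.
# Checkpoint path, input size and normalization constants are assumptions.
import torch
import torchvision.transforms as T
from PIL import Image
from model import BiSeNet

net = BiSeNet(n_classes=19)
net.load_state_dict(torch.load('path/to/bisenet_checkpoint.pth', map_location='cpu'))
net.eval()

preprocess = T.Compose([
    T.Resize((512, 512)),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = preprocess(Image.open('face.jpg').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    out = net(img)[0]                        # main output head: (1, 19, 512, 512)
parsing = out.squeeze(0).argmax(0).numpy()   # per-pixel class ids in [0, 18]
```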
models/BiSeNet/modules/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .bn import ABN, InPlaceABN, InPlaceABNSync
+ from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
+ from .misc import GlobalAvgPool2d, SingleGPU
+ from .residual import IdentityResidualBlock
+ from .dense import DenseModule
models/BiSeNet/modules/bn.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as functional
4
+
5
+ try:
6
+ from queue import Queue
7
+ except ImportError:
8
+ from Queue import Queue
9
+
10
+ from .functions import *
11
+
12
+
13
+ class ABN(nn.Module):
14
+ """Activated Batch Normalization
15
+
16
+ This gathers a `BatchNorm2d` and an activation function in a single module
17
+ """
18
+
19
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
20
+ """Creates an Activated Batch Normalization module
21
+
22
+ Parameters
23
+ ----------
24
+ num_features : int
25
+ Number of feature channels in the input and output.
26
+ eps : float
27
+ Small constant to prevent numerical issues.
28
+ momentum : float
29
+ Momentum factor applied to compute running statistics.
30
+ affine : bool
31
+ If `True` apply learned scale and shift transformation after normalization.
32
+ activation : str
33
+ Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
34
+ slope : float
35
+ Negative slope for the `leaky_relu` activation.
36
+ """
37
+ super(ABN, self).__init__()
38
+ self.num_features = num_features
39
+ self.affine = affine
40
+ self.eps = eps
41
+ self.momentum = momentum
42
+ self.activation = activation
43
+ self.slope = slope
44
+ if self.affine:
45
+ self.weight = nn.Parameter(torch.ones(num_features))
46
+ self.bias = nn.Parameter(torch.zeros(num_features))
47
+ else:
48
+ self.register_parameter('weight', None)
49
+ self.register_parameter('bias', None)
50
+ self.register_buffer('running_mean', torch.zeros(num_features))
51
+ self.register_buffer('running_var', torch.ones(num_features))
52
+ self.reset_parameters()
53
+
54
+ def reset_parameters(self):
55
+ nn.init.constant_(self.running_mean, 0)
56
+ nn.init.constant_(self.running_var, 1)
57
+ if self.affine:
58
+ nn.init.constant_(self.weight, 1)
59
+ nn.init.constant_(self.bias, 0)
60
+
61
+ def forward(self, x):
62
+ x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
63
+ self.training, self.momentum, self.eps)
64
+
65
+ if self.activation == ACT_RELU:
66
+ return functional.relu(x, inplace=True)
67
+ elif self.activation == ACT_LEAKY_RELU:
68
+ return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
69
+ elif self.activation == ACT_ELU:
70
+ return functional.elu(x, inplace=True)
71
+ else:
72
+ return x
73
+
74
+ def __repr__(self):
75
+ rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
76
+ ' affine={affine}, activation={activation}'
77
+ if self.activation == "leaky_relu":
78
+ rep += ', slope={slope})'
79
+ else:
80
+ rep += ')'
81
+ return rep.format(name=self.__class__.__name__, **self.__dict__)
82
+
83
+
84
+ class InPlaceABN(ABN):
85
+ """InPlace Activated Batch Normalization"""
86
+
87
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
88
+ """Creates an InPlace Activated Batch Normalization module
89
+
90
+ Parameters
91
+ ----------
92
+ num_features : int
93
+ Number of feature channels in the input and output.
94
+ eps : float
95
+ Small constant to prevent numerical issues.
96
+ momentum : float
97
+ Momentum factor applied to compute running statistics.
98
+ affine : bool
99
+ If `True` apply learned scale and shift transformation after normalization.
100
+ activation : str
101
+ Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
102
+ slope : float
103
+ Negative slope for the `leaky_relu` activation.
104
+ """
105
+ super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
106
+
107
+ def forward(self, x):
108
+ return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
109
+ self.training, self.momentum, self.eps, self.activation, self.slope)
110
+
111
+
112
+ class InPlaceABNSync(ABN):
113
+ """InPlace Activated Batch Normalization with cross-GPU synchronization
114
+ This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
115
+ """
116
+
117
+ def forward(self, x):
118
+ return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
119
+ self.training, self.momentum, self.eps, self.activation, self.slope)
120
+
121
+ def __repr__(self):
122
+ rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
123
+ ' affine={affine}, activation={activation}'
124
+ if self.activation == "leaky_relu":
125
+ rep += ', slope={slope})'
126
+ else:
127
+ rep += ')'
128
+ return rep.format(name=self.__class__.__name__, **self.__dict__)
129
+
130
+
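
The plain `ABN` above is pure PyTorch, while `InPlaceABN`/`InPlaceABNSync` rely on the JIT-compiled extension under `modules/src` (loaded by `functions.py` below); a hedged sketch of `ABN` as a fused BN-plus-activation layer. Note that importing `modules.bn` already triggers that extension build, which is an environment assumption here.

```python
# Hedged sketch: ABN as a drop-in replacement for BatchNorm2d + LeakyReLU.
# Importing modules.bn pulls in the JIT-compiled in-place ABN extension, so this
# assumes a working CUDA toolchain even though ABN itself also runs on CPU.
import torch
import torch.nn as nn
from modules.bn import ABN

block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    ABN(64, activation="leaky_relu", slope=0.01),   # BN and activation in one module
)
x = torch.randn(2, 3, 32, 32)
print(block(x).shape)   # torch.Size([2, 64, 32, 32])
```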
models/BiSeNet/modules/deeplab.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as functional
4
+
5
+ from models._util import try_index
6
+ from .bn import ABN
7
+
8
+
9
+ class DeeplabV3(nn.Module):
10
+ def __init__(self,
11
+ in_channels,
12
+ out_channels,
13
+ hidden_channels=256,
14
+ dilations=(12, 24, 36),
15
+ norm_act=ABN,
16
+ pooling_size=None):
17
+ super(DeeplabV3, self).__init__()
18
+ self.pooling_size = pooling_size
19
+
20
+ self.map_convs = nn.ModuleList([
21
+ nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
22
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
23
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
24
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
25
+ ])
26
+ self.map_bn = norm_act(hidden_channels * 4)
27
+
28
+ self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
29
+ self.global_pooling_bn = norm_act(hidden_channels)
30
+
31
+ self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
32
+ self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
33
+ self.red_bn = norm_act(out_channels)
34
+
35
+ self.reset_parameters(self.map_bn.activation, self.map_bn.slope)
36
+
37
+ def reset_parameters(self, activation, slope):
38
+ gain = nn.init.calculate_gain(activation, slope)
39
+ for m in self.modules():
40
+ if isinstance(m, nn.Conv2d):
41
+ nn.init.xavier_normal_(m.weight.data, gain)
42
+ if hasattr(m, "bias") and m.bias is not None:
43
+ nn.init.constant_(m.bias, 0)
44
+ elif isinstance(m, ABN):
45
+ if hasattr(m, "weight") and m.weight is not None:
46
+ nn.init.constant_(m.weight, 1)
47
+ if hasattr(m, "bias") and m.bias is not None:
48
+ nn.init.constant_(m.bias, 0)
49
+
50
+ def forward(self, x):
51
+ # Map convolutions
52
+ out = torch.cat([m(x) for m in self.map_convs], dim=1)
53
+ out = self.map_bn(out)
54
+ out = self.red_conv(out)
55
+
56
+ # Global pooling
57
+ pool = self._global_pooling(x)
58
+ pool = self.global_pooling_conv(pool)
59
+ pool = self.global_pooling_bn(pool)
60
+ pool = self.pool_red_conv(pool)
61
+ if self.training or self.pooling_size is None:
62
+ pool = pool.repeat(1, 1, x.size(2), x.size(3))
63
+
64
+ out += pool
65
+ out = self.red_bn(out)
66
+ return out
67
+
68
+ def _global_pooling(self, x):
69
+ if self.training or self.pooling_size is None:
70
+ pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
71
+ pool = pool.view(x.size(0), x.size(1), 1, 1)
72
+ else:
73
+ pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
74
+ min(try_index(self.pooling_size, 1), x.shape[3]))
75
+ padding = (
76
+ (pooling_size[1] - 1) // 2,
77
+ (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
78
+ (pooling_size[0] - 1) // 2,
79
+ (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
80
+ )
81
+
82
+ pool = functional.avg_pool2d(x, pooling_size, stride=1)
83
+ pool = functional.pad(pool, pad=padding, mode="replicate")
84
+ return pool
models/BiSeNet/modules/dense.py ADDED
@@ -0,0 +1,42 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .bn import ABN
7
+
8
+
9
+ class DenseModule(nn.Module):
10
+ def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
11
+ super(DenseModule, self).__init__()
12
+ self.in_channels = in_channels
13
+ self.growth = growth
14
+ self.layers = layers
15
+
16
+ self.convs1 = nn.ModuleList()
17
+ self.convs3 = nn.ModuleList()
18
+ for i in range(self.layers):
19
+ self.convs1.append(nn.Sequential(OrderedDict([
20
+ ("bn", norm_act(in_channels)),
21
+ ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
22
+ ])))
23
+ self.convs3.append(nn.Sequential(OrderedDict([
24
+ ("bn", norm_act(self.growth * bottleneck_factor)),
25
+ ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
26
+ dilation=dilation))
27
+ ])))
28
+ in_channels += self.growth
29
+
30
+ @property
31
+ def out_channels(self):
32
+ return self.in_channels + self.growth * self.layers
33
+
34
+ def forward(self, x):
35
+ inputs = [x]
36
+ for i in range(self.layers):
37
+ x = torch.cat(inputs, dim=1)
38
+ x = self.convs1[i](x)
39
+ x = self.convs3[i](x)
40
+ inputs += [x]
41
+
42
+ return torch.cat(inputs, dim=1)
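
The channel bookkeeping implied by the `out_channels` property above, shown with illustrative numbers (a worked note, not code from this diff).

```python
# Hedged sketch: each dense layer appends `growth` channels to the running input,
# so the concatenated output has in_channels + growth * layers channels.
in_channels, growth, layers = 64, 32, 4
out_channels = in_channels + growth * layers   # mirrors DenseModule.out_channels
print(out_channels)                            # 192
```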
models/BiSeNet/modules/functions.py ADDED
@@ -0,0 +1,234 @@
1
+ from os import path
2
+ import torch
3
+ import torch.distributed as dist
4
+ import torch.autograd as autograd
5
+ import torch.cuda.comm as comm
6
+ from torch.autograd.function import once_differentiable
7
+ from torch.utils.cpp_extension import load
8
+
9
+ _src_path = path.join(path.dirname(path.abspath(__file__)), "src")
10
+ _backend = load(name="inplace_abn",
11
+ extra_cflags=["-O3"],
12
+ sources=[path.join(_src_path, f) for f in [
13
+ "inplace_abn.cpp",
14
+ "inplace_abn_cpu.cpp",
15
+ "inplace_abn_cuda.cu",
16
+ "inplace_abn_cuda_half.cu"
17
+ ]],
18
+ extra_cuda_cflags=["--expt-extended-lambda"])
19
+
20
+ # Activation names
21
+ ACT_RELU = "relu"
22
+ ACT_LEAKY_RELU = "leaky_relu"
23
+ ACT_ELU = "elu"
24
+ ACT_NONE = "none"
25
+
26
+
27
+ def _check(fn, *args, **kwargs):
28
+ success = fn(*args, **kwargs)
29
+ if not success:
30
+ raise RuntimeError("CUDA Error encountered in {}".format(fn))
31
+
32
+
33
+ def _broadcast_shape(x):
34
+ out_size = []
35
+ for i, s in enumerate(x.size()):
36
+ if i != 1:
37
+ out_size.append(1)
38
+ else:
39
+ out_size.append(s)
40
+ return out_size
41
+
42
+
43
+ def _reduce(x):
44
+ if len(x.size()) == 2:
45
+ return x.sum(dim=0)
46
+ else:
47
+ n, c = x.size()[0:2]
48
+ return x.contiguous().view((n, c, -1)).sum(2).sum(0)
49
+
50
+
51
+ def _count_samples(x):
52
+ count = 1
53
+ for i, s in enumerate(x.size()):
54
+ if i != 1:
55
+ count *= s
56
+ return count
57
+
58
+
59
+ def _act_forward(ctx, x):
60
+ if ctx.activation == ACT_LEAKY_RELU:
61
+ _backend.leaky_relu_forward(x, ctx.slope)
62
+ elif ctx.activation == ACT_ELU:
63
+ _backend.elu_forward(x)
64
+ elif ctx.activation == ACT_NONE:
65
+ pass
66
+
67
+
68
+ def _act_backward(ctx, x, dx):
69
+ if ctx.activation == ACT_LEAKY_RELU:
70
+ _backend.leaky_relu_backward(x, dx, ctx.slope)
71
+ elif ctx.activation == ACT_ELU:
72
+ _backend.elu_backward(x, dx)
73
+ elif ctx.activation == ACT_NONE:
74
+ pass
75
+
76
+
77
+ class InPlaceABN(autograd.Function):
78
+ @staticmethod
79
+ def forward(ctx, x, weight, bias, running_mean, running_var,
80
+ training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
81
+ # Save context
82
+ ctx.training = training
83
+ ctx.momentum = momentum
84
+ ctx.eps = eps
85
+ ctx.activation = activation
86
+ ctx.slope = slope
87
+ ctx.affine = weight is not None and bias is not None
88
+
89
+ # Prepare inputs
90
+ count = _count_samples(x)
91
+ x = x.contiguous()
92
+ weight = weight.contiguous() if ctx.affine else x.new_empty(0)
93
+ bias = bias.contiguous() if ctx.affine else x.new_empty(0)
94
+
95
+ if ctx.training:
96
+ mean, var = _backend.mean_var(x)
97
+
98
+ # Update running stats
99
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
100
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
101
+
102
+ # Mark in-place modified tensors
103
+ ctx.mark_dirty(x, running_mean, running_var)
104
+ else:
105
+ mean, var = running_mean.contiguous(), running_var.contiguous()
106
+ ctx.mark_dirty(x)
107
+
108
+ # BN forward + activation
109
+ _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
110
+ _act_forward(ctx, x)
111
+
112
+ # Output
113
+ ctx.var = var
114
+ ctx.save_for_backward(x, var, weight, bias)
115
+ return x
116
+
117
+ @staticmethod
118
+ @once_differentiable
119
+ def backward(ctx, dz):
120
+ z, var, weight, bias = ctx.saved_tensors
121
+ dz = dz.contiguous()
122
+
123
+ # Undo activation
124
+ _act_backward(ctx, z, dz)
125
+
126
+ if ctx.training:
127
+ edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
128
+ else:
129
+ # TODO: implement simplified CUDA backward for inference mode
130
+ edz = dz.new_zeros(dz.size(1))
131
+ eydz = dz.new_zeros(dz.size(1))
132
+
133
+ dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
134
+ dweight = eydz * weight.sign() if ctx.affine else None
135
+ dbias = edz if ctx.affine else None
136
+
137
+ return dx, dweight, dbias, None, None, None, None, None, None, None
138
+
139
+ class InPlaceABNSync(autograd.Function):
140
+ @classmethod
141
+ def forward(cls, ctx, x, weight, bias, running_mean, running_var,
142
+ training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
143
+ # Save context
144
+ ctx.training = training
145
+ ctx.momentum = momentum
146
+ ctx.eps = eps
147
+ ctx.activation = activation
148
+ ctx.slope = slope
149
+ ctx.affine = weight is not None and bias is not None
150
+
151
+ # Prepare inputs
152
+ ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1
153
+
154
+ #count = _count_samples(x)
155
+ batch_size = x.new_tensor([x.shape[0]],dtype=torch.long)
156
+
157
+ x = x.contiguous()
158
+ weight = weight.contiguous() if ctx.affine else x.new_empty(0)
159
+ bias = bias.contiguous() if ctx.affine else x.new_empty(0)
160
+
161
+ if ctx.training:
162
+ mean, var = _backend.mean_var(x)
163
+ if ctx.world_size>1:
164
+ # get global batch size
165
+ if equal_batches:
166
+ batch_size *= ctx.world_size
167
+ else:
168
+ dist.all_reduce(batch_size, dist.ReduceOp.SUM)
169
+
170
+ ctx.factor = x.shape[0]/float(batch_size.item())
171
+
172
+ mean_all = mean.clone() * ctx.factor
173
+ dist.all_reduce(mean_all, dist.ReduceOp.SUM)
174
+
175
+ var_all = (var + (mean - mean_all) ** 2) * ctx.factor
176
+ dist.all_reduce(var_all, dist.ReduceOp.SUM)
177
+
178
+ mean = mean_all
179
+ var = var_all
180
+
181
+ # Update running stats
182
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
183
+ count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1]
184
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))
185
+
186
+ # Mark in-place modified tensors
187
+ ctx.mark_dirty(x, running_mean, running_var)
188
+ else:
189
+ mean, var = running_mean.contiguous(), running_var.contiguous()
190
+ ctx.mark_dirty(x)
191
+
192
+ # BN forward + activation
193
+ _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
194
+ _act_forward(ctx, x)
195
+
196
+ # Output
197
+ ctx.var = var
198
+ ctx.save_for_backward(x, var, weight, bias)
199
+ return x
200
+
201
+ @staticmethod
202
+ @once_differentiable
203
+ def backward(ctx, dz):
204
+ z, var, weight, bias = ctx.saved_tensors
205
+ dz = dz.contiguous()
206
+
207
+ # Undo activation
208
+ _act_backward(ctx, z, dz)
209
+
210
+ if ctx.training:
211
+ edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
212
+ edz_local = edz.clone()
213
+ eydz_local = eydz.clone()
214
+
215
+ if ctx.world_size>1:
216
+ edz *= ctx.factor
217
+ dist.all_reduce(edz, dist.ReduceOp.SUM)
218
+
219
+ eydz *= ctx.factor
220
+ dist.all_reduce(eydz, dist.ReduceOp.SUM)
221
+ else:
222
+ edz_local = edz = dz.new_zeros(dz.size(1))
223
+ eydz_local = eydz = dz.new_zeros(dz.size(1))
224
+
225
+ dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
226
+ dweight = eydz_local * weight.sign() if ctx.affine else None
227
+ dbias = edz_local if ctx.affine else None
228
+
229
+ return dx, dweight, dbias, None, None, None, None, None, None, None
230
+
231
+ inplace_abn = InPlaceABN.apply
232
+ inplace_abn_sync = InPlaceABNSync.apply
233
+
234
+ __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
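
Since the extension is compiled on first import via `torch.utils.cpp_extension.load`, a quick smoke test of the in-place path can be run once the build succeeds; a hedged sketch in which the working directory, package layout and GPU availability are all assumptions.

```python
# Hedged sketch: smoke-test the JIT-compiled in-place ABN through the module API.
# Assumes this is run from models/BiSeNet so that `modules` is importable, and that
# a CUDA device plus a toolchain matching the installed PyTorch are available.
import torch
from modules import InPlaceABN

abn = InPlaceABN(8).cuda()
x = torch.randn(2, 8, 16, 16, device='cuda')
print(abn(x).shape)   # torch.Size([2, 8, 16, 16])
```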
models/BiSeNet/modules/misc.py ADDED
@@ -0,0 +1,21 @@
+ import torch.nn as nn
+ import torch
+ import torch.distributed as dist
+
+ class GlobalAvgPool2d(nn.Module):
+ def __init__(self):
+ """Global average pooling over the input's spatial dimensions"""
+ super(GlobalAvgPool2d, self).__init__()
+
+ def forward(self, inputs):
+ in_size = inputs.size()
+ return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
+
+ class SingleGPU(nn.Module):
+ def __init__(self, module):
+ super(SingleGPU, self).__init__()
+ self.module=module
+
+ def forward(self, input):
+ return self.module(input.cuda(non_blocking=True))
+
models/BiSeNet/modules/residual.py ADDED
@@ -0,0 +1,88 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch.nn as nn
4
+
5
+ from .bn import ABN
6
+
7
+
8
+ class IdentityResidualBlock(nn.Module):
9
+ def __init__(self,
10
+ in_channels,
11
+ channels,
12
+ stride=1,
13
+ dilation=1,
14
+ groups=1,
15
+ norm_act=ABN,
16
+ dropout=None):
17
+ """Configurable identity-mapping residual block
18
+
19
+ Parameters
20
+ ----------
21
+ in_channels : int
22
+ Number of input channels.
23
+ channels : list of int
24
+ Number of channels in the internal feature maps. Can either have two or three elements: if two, construct
+ a residual block with two `3 x 3` convolutions; if three, construct a bottleneck block with `1 x 1`, then
+ `3 x 3`, then `1 x 1` convolutions.
27
+ stride : int
28
+ Stride of the first `3 x 3` convolution
29
+ dilation : int
30
+ Dilation to apply to the `3 x 3` convolutions.
31
+ groups : int
32
+ Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
33
+ bottleneck blocks.
34
+ norm_act : callable
35
+ Function to create normalization / activation Module.
36
+ dropout: callable
37
+ Function to create Dropout Module.
38
+ """
39
+ super(IdentityResidualBlock, self).__init__()
40
+
41
+ # Check parameters for inconsistencies
42
+ if len(channels) != 2 and len(channels) != 3:
43
+ raise ValueError("channels must contain either two or three values")
44
+ if len(channels) == 2 and groups != 1:
45
+ raise ValueError("groups > 1 are only valid if len(channels) == 3")
46
+
47
+ is_bottleneck = len(channels) == 3
48
+ need_proj_conv = stride != 1 or in_channels != channels[-1]
49
+
50
+ self.bn1 = norm_act(in_channels)
51
+ if not is_bottleneck:
52
+ layers = [
53
+ ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
54
+ dilation=dilation)),
55
+ ("bn2", norm_act(channels[0])),
56
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
57
+ dilation=dilation))
58
+ ]
59
+ if dropout is not None:
60
+ layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
61
+ else:
62
+ layers = [
63
+ ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
64
+ ("bn2", norm_act(channels[0])),
65
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
66
+ groups=groups, dilation=dilation)),
67
+ ("bn3", norm_act(channels[1])),
68
+ ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
69
+ ]
70
+ if dropout is not None:
71
+ layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
72
+ self.convs = nn.Sequential(OrderedDict(layers))
73
+
74
+ if need_proj_conv:
75
+ self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
76
+
77
+ def forward(self, x):
78
+ if hasattr(self, "proj_conv"):
79
+ bn1 = self.bn1(x)
80
+ shortcut = self.proj_conv(bn1)
81
+ else:
82
+ shortcut = x.clone()
83
+ bn1 = self.bn1(x)
84
+
85
+ out = self.convs(bn1)
86
+ out.add_(shortcut)
87
+
88
+ return out
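
A hedged sketch of the two `channels` forms described in the (corrected) docstring above; importing `modules.residual` pulls in the compiled ABN extension, so this only runs where that build succeeds.

```python
# Hedged sketch: basic vs. bottleneck construction of IdentityResidualBlock.
# Output shapes follow from the code above; the import requires the in-place ABN
# extension to build, since modules.residual imports ABN from modules.bn.
import torch
from modules.residual import IdentityResidualBlock

basic = IdentityResidualBlock(64, [64, 64])             # two 3x3 convolutions
bottleneck = IdentityResidualBlock(64, [64, 64, 256])   # 1x1 -> 3x3 -> 1x1, projection shortcut

x = torch.randn(1, 64, 56, 56)
print(basic(x).shape)        # torch.Size([1, 64, 56, 56])
print(bottleneck(x).shape)   # torch.Size([1, 256, 56, 56])
```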
models/BiSeNet/modules/src/checks.h ADDED
@@ -0,0 +1,15 @@
+ #pragma once
+
+ #include <ATen/ATen.h>
+
+ // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
+ #ifndef AT_CHECK
+ #define AT_CHECK AT_ASSERT
+ #endif
+
+ #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
+ #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
+
+ #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+ #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
models/BiSeNet/modules/src/inplace_abn.cpp ADDED
@@ -0,0 +1,95 @@
1
+ #include <torch/extension.h>
2
+
3
+ #include <vector>
4
+
5
+ #include "inplace_abn.h"
6
+
7
+ std::vector<at::Tensor> mean_var(at::Tensor x) {
8
+ if (x.is_cuda()) {
9
+ if (x.type().scalarType() == at::ScalarType::Half) {
10
+ return mean_var_cuda_h(x);
11
+ } else {
12
+ return mean_var_cuda(x);
13
+ }
14
+ } else {
15
+ return mean_var_cpu(x);
16
+ }
17
+ }
18
+
19
+ at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
20
+ bool affine, float eps) {
21
+ if (x.is_cuda()) {
22
+ if (x.type().scalarType() == at::ScalarType::Half) {
23
+ return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
24
+ } else {
25
+ return forward_cuda(x, mean, var, weight, bias, affine, eps);
26
+ }
27
+ } else {
28
+ return forward_cpu(x, mean, var, weight, bias, affine, eps);
29
+ }
30
+ }
31
+
32
+ std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
33
+ bool affine, float eps) {
34
+ if (z.is_cuda()) {
35
+ if (z.type().scalarType() == at::ScalarType::Half) {
36
+ return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
37
+ } else {
38
+ return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
39
+ }
40
+ } else {
41
+ return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
42
+ }
43
+ }
44
+
45
+ at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
46
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
47
+ if (z.is_cuda()) {
48
+ if (z.type().scalarType() == at::ScalarType::Half) {
49
+ return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
50
+ } else {
51
+ return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
52
+ }
53
+ } else {
54
+ return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
55
+ }
56
+ }
57
+
58
+ void leaky_relu_forward(at::Tensor z, float slope) {
59
+ at::leaky_relu_(z, slope);
60
+ }
61
+
62
+ void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
63
+ if (z.is_cuda()) {
64
+ if (z.type().scalarType() == at::ScalarType::Half) {
65
+ return leaky_relu_backward_cuda_h(z, dz, slope);
66
+ } else {
67
+ return leaky_relu_backward_cuda(z, dz, slope);
68
+ }
69
+ } else {
70
+ return leaky_relu_backward_cpu(z, dz, slope);
71
+ }
72
+ }
73
+
74
+ void elu_forward(at::Tensor z) {
75
+ at::elu_(z);
76
+ }
77
+
78
+ void elu_backward(at::Tensor z, at::Tensor dz) {
79
+ if (z.is_cuda()) {
80
+ return elu_backward_cuda(z, dz);
81
+ } else {
82
+ return elu_backward_cpu(z, dz);
83
+ }
84
+ }
85
+
86
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
87
+ m.def("mean_var", &mean_var, "Mean and variance computation");
88
+ m.def("forward", &forward, "In-place forward computation");
89
+ m.def("edz_eydz", &edz_eydz, "First part of backward computation");
90
+ m.def("backward", &backward, "Second part of backward computation");
91
+ m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
92
+ m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
93
+ m.def("elu_forward", &elu_forward, "Elu forward computation");
94
+ m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
95
+ }
models/BiSeNet/modules/src/inplace_abn.h ADDED
@@ -0,0 +1,88 @@
1
+ #pragma once
2
+
3
+ #include <ATen/ATen.h>
4
+
5
+ #include <vector>
6
+
7
+ std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8
+ std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9
+ std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
10
+
11
+ at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
12
+ bool affine, float eps);
13
+ at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
14
+ bool affine, float eps);
15
+ at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16
+ bool affine, float eps);
17
+
18
+ std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
19
+ bool affine, float eps);
20
+ std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
21
+ bool affine, float eps);
22
+ std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
23
+ bool affine, float eps);
24
+
25
+ at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
26
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
27
+ at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
28
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
29
+ at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
30
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
31
+
32
+ void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
33
+ void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
34
+ void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
35
+
36
+ void elu_backward_cpu(at::Tensor z, at::Tensor dz);
37
+ void elu_backward_cuda(at::Tensor z, at::Tensor dz);
38
+
39
+ static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
40
+ num = x.size(0);
41
+ chn = x.size(1);
42
+ sp = 1;
43
+ for (int64_t i = 2; i < x.ndimension(); ++i)
44
+ sp *= x.size(i);
45
+ }
46
+
47
+ /*
48
+ * Specialized CUDA reduction functions for BN
49
+ */
50
+ #ifdef __CUDACC__
51
+
52
+ #include "utils/cuda.cuh"
53
+
54
+ template <typename T, typename Op>
55
+ __device__ T reduce(Op op, int plane, int N, int S) {
56
+ T sum = (T)0;
57
+ for (int batch = 0; batch < N; ++batch) {
58
+ for (int x = threadIdx.x; x < S; x += blockDim.x) {
59
+ sum += op(batch, plane, x);
60
+ }
61
+ }
62
+
63
+ // sum over NumThreads within a warp
64
+ sum = warpSum(sum);
65
+
66
+ // 'transpose', and reduce within warp again
67
+ __shared__ T shared[32];
68
+ __syncthreads();
69
+ if (threadIdx.x % WARP_SIZE == 0) {
70
+ shared[threadIdx.x / WARP_SIZE] = sum;
71
+ }
72
+ if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
73
+ // zero out the other entries in shared
74
+ shared[threadIdx.x] = (T)0;
75
+ }
76
+ __syncthreads();
77
+ if (threadIdx.x / WARP_SIZE == 0) {
78
+ sum = warpSum(shared[threadIdx.x]);
79
+ if (threadIdx.x == 0) {
80
+ shared[0] = sum;
81
+ }
82
+ }
83
+ __syncthreads();
84
+
85
+ // Everyone picks it up, should be broadcast into the whole gradInput
86
+ return shared[0];
87
+ }
88
+ #endif
models/BiSeNet/modules/src/inplace_abn_cpu.cpp ADDED
@@ -0,0 +1,119 @@
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <vector>
4
+
5
+ #include "utils/checks.h"
6
+ #include "inplace_abn.h"
7
+
8
+ at::Tensor reduce_sum(at::Tensor x) {
9
+ if (x.ndimension() == 2) {
10
+ return x.sum(0);
11
+ } else {
12
+ auto x_view = x.view({x.size(0), x.size(1), -1});
13
+ return x_view.sum(-1).sum(0);
14
+ }
15
+ }
16
+
17
+ at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
18
+ if (x.ndimension() == 2) {
19
+ return v;
20
+ } else {
21
+ std::vector<int64_t> broadcast_size = {1, -1};
22
+ for (int64_t i = 2; i < x.ndimension(); ++i)
23
+ broadcast_size.push_back(1);
24
+
25
+ return v.view(broadcast_size);
26
+ }
27
+ }
28
+
29
+ int64_t count(at::Tensor x) {
30
+ int64_t count = x.size(0);
31
+ for (int64_t i = 2; i < x.ndimension(); ++i)
32
+ count *= x.size(i);
33
+
34
+ return count;
35
+ }
36
+
37
+ at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
38
+ if (affine) {
39
+ return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
40
+ } else {
41
+ return z;
42
+ }
43
+ }
44
+
45
+ std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
46
+ auto num = count(x);
47
+ auto mean = reduce_sum(x) / num;
48
+ auto diff = x - broadcast_to(mean, x);
49
+ auto var = reduce_sum(diff.pow(2)) / num;
50
+
51
+ return {mean, var};
52
+ }
53
+
54
+ at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
55
+ bool affine, float eps) {
56
+ auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
57
+ auto mul = at::rsqrt(var + eps) * gamma;
58
+
59
+ x.sub_(broadcast_to(mean, x));
60
+ x.mul_(broadcast_to(mul, x));
61
+ if (affine) x.add_(broadcast_to(bias, x));
62
+
63
+ return x;
64
+ }
65
+
66
+ std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
67
+ bool affine, float eps) {
68
+ auto edz = reduce_sum(dz);
69
+ auto y = invert_affine(z, weight, bias, affine, eps);
70
+ auto eydz = reduce_sum(y * dz);
71
+
72
+ return {edz, eydz};
73
+ }
74
+
75
+ at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
76
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
77
+ auto y = invert_affine(z, weight, bias, affine, eps);
78
+ auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
79
+
80
+ auto num = count(z);
81
+ auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
82
+ return dx;
83
+ }
84
+
85
+ void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
86
+ CHECK_CPU_INPUT(z);
87
+ CHECK_CPU_INPUT(dz);
88
+
89
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
90
+ int64_t count = z.numel();
91
+ auto *_z = z.data<scalar_t>();
92
+ auto *_dz = dz.data<scalar_t>();
93
+
94
+ for (int64_t i = 0; i < count; ++i) {
95
+ if (_z[i] < 0) {
96
+ _z[i] *= 1 / slope;
97
+ _dz[i] *= slope;
98
+ }
99
+ }
100
+ }));
101
+ }
102
+
103
+ void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
104
+ CHECK_CPU_INPUT(z);
105
+ CHECK_CPU_INPUT(dz);
106
+
107
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
108
+ int64_t count = z.numel();
109
+ auto *_z = z.data<scalar_t>();
110
+ auto *_dz = dz.data<scalar_t>();
111
+
112
+ for (int64_t i = 0; i < count; ++i) {
113
+ if (_z[i] < 0) {
114
+ _z[i] = log1p(_z[i]);
115
+ _dz[i] *= (_z[i] + 1.f);
116
+ }
117
+ }
118
+ }));
119
+ }
models/BiSeNet/modules/src/inplace_abn_cuda.cu ADDED
@@ -0,0 +1,333 @@
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <thrust/device_ptr.h>
4
+ #include <thrust/transform.h>
5
+
6
+ #include <vector>
7
+
8
+ #include "utils/checks.h"
9
+ #include "utils/cuda.cuh"
10
+ #include "inplace_abn.h"
11
+
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ // Operations for reduce
15
+ template<typename T>
16
+ struct SumOp {
17
+ __device__ SumOp(const T *t, int c, int s)
18
+ : tensor(t), chn(c), sp(s) {}
19
+ __device__ __forceinline__ T operator()(int batch, int plane, int n) {
20
+ return tensor[(batch * chn + plane) * sp + n];
21
+ }
22
+ const T *tensor;
23
+ const int chn;
24
+ const int sp;
25
+ };
26
+
27
+ template<typename T>
28
+ struct VarOp {
29
+ __device__ VarOp(T m, const T *t, int c, int s)
30
+ : mean(m), tensor(t), chn(c), sp(s) {}
31
+ __device__ __forceinline__ T operator()(int batch, int plane, int n) {
32
+ T val = tensor[(batch * chn + plane) * sp + n];
33
+ return (val - mean) * (val - mean);
34
+ }
35
+ const T mean;
36
+ const T *tensor;
37
+ const int chn;
38
+ const int sp;
39
+ };
40
+
41
+ template<typename T>
42
+ struct GradOp {
43
+ __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
44
+ : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
45
+ __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
46
+ T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
47
+ T _dz = dz[(batch * chn + plane) * sp + n];
48
+ return Pair<T>(_dz, _y * _dz);
49
+ }
50
+ const T weight;
51
+ const T bias;
52
+ const T *z;
53
+ const T *dz;
54
+ const int chn;
55
+ const int sp;
56
+ };
57
+
58
+ /***********
59
+ * mean_var
60
+ ***********/
61
+
62
+ template<typename T>
63
+ __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
64
+ int plane = blockIdx.x;
65
+ T norm = T(1) / T(num * sp);
66
+
67
+ T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
68
+ __syncthreads();
69
+ T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;
70
+
71
+ if (threadIdx.x == 0) {
72
+ mean[plane] = _mean;
73
+ var[plane] = _var;
74
+ }
75
+ }
76
+
77
+ std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
78
+ CHECK_CUDA_INPUT(x);
79
+
80
+ // Extract dimensions
81
+ int64_t num, chn, sp;
82
+ get_dims(x, num, chn, sp);
83
+
84
+ // Prepare output tensors
85
+ auto mean = at::empty({chn}, x.options());
86
+ auto var = at::empty({chn}, x.options());
87
+
88
+ // Run kernel
89
+ dim3 blocks(chn);
90
+ dim3 threads(getNumThreads(sp));
91
+ auto stream = at::cuda::getCurrentCUDAStream();
92
+ AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
93
+ mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
94
+ x.data<scalar_t>(),
95
+ mean.data<scalar_t>(),
96
+ var.data<scalar_t>(),
97
+ num, chn, sp);
98
+ }));
99
+
+   return {mean, var};
+ }
+
+ /**********
+  * forward
+  **********/
+
+ template<typename T>
+ __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
+                                bool affine, float eps, int num, int chn, int sp) {
+   int plane = blockIdx.x;
+
+   T _mean = mean[plane];
+   T _var = var[plane];
+   T _weight = affine ? abs(weight[plane]) + eps : T(1);
+   T _bias = affine ? bias[plane] : T(0);
+
+   T mul = rsqrt(_var + eps) * _weight;
+
+   for (int batch = 0; batch < num; ++batch) {
+     for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+       T _x = x[(batch * chn + plane) * sp + n];
+       T _y = (_x - _mean) * mul + _bias;
+
+       x[(batch * chn + plane) * sp + n] = _y;
+     }
+   }
+ }
+
+ at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+                         bool affine, float eps) {
+   CHECK_CUDA_INPUT(x);
+   CHECK_CUDA_INPUT(mean);
+   CHECK_CUDA_INPUT(var);
+   CHECK_CUDA_INPUT(weight);
+   CHECK_CUDA_INPUT(bias);
+
+   // Extract dimensions
+   int64_t num, chn, sp;
+   get_dims(x, num, chn, sp);
+
+   // Run kernel
+   dim3 blocks(chn);
+   dim3 threads(getNumThreads(sp));
+   auto stream = at::cuda::getCurrentCUDAStream();
+   AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
+     forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+         x.data<scalar_t>(),
+         mean.data<scalar_t>(),
+         var.data<scalar_t>(),
+         weight.data<scalar_t>(),
+         bias.data<scalar_t>(),
+         affine, eps, num, chn, sp);
+   }));
+
+   return x;
+ }
+
+ /***********
+  * edz_eydz
+  ***********/
+
+ template<typename T>
+ __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
+                                 T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
+   int plane = blockIdx.x;
+
+   T _weight = affine ? abs(weight[plane]) + eps : 1.f;
+   T _bias = affine ? bias[plane] : 0.f;
+
+   Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
+   __syncthreads();
+
+   if (threadIdx.x == 0) {
+     edz[plane] = res.v1;
+     eydz[plane] = res.v2;
+   }
+ }
+
+ std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+                                       bool affine, float eps) {
+   CHECK_CUDA_INPUT(z);
+   CHECK_CUDA_INPUT(dz);
+   CHECK_CUDA_INPUT(weight);
+   CHECK_CUDA_INPUT(bias);
+
+   // Extract dimensions
+   int64_t num, chn, sp;
+   get_dims(z, num, chn, sp);
+
+   auto edz = at::empty({chn}, z.options());
+   auto eydz = at::empty({chn}, z.options());
+
+   // Run kernel
+   dim3 blocks(chn);
+   dim3 threads(getNumThreads(sp));
+   auto stream = at::cuda::getCurrentCUDAStream();
+   AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
+     edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+         z.data<scalar_t>(),
+         dz.data<scalar_t>(),
+         weight.data<scalar_t>(),
+         bias.data<scalar_t>(),
+         edz.data<scalar_t>(),
+         eydz.data<scalar_t>(),
+         affine, eps, num, chn, sp);
+   }));
+
+   return {edz, eydz};
+ }
+
+ /***********
+  * backward
+  ***********/
+
+ template<typename T>
+ __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
+                                 const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
+   int plane = blockIdx.x;
+
+   T _weight = affine ? abs(weight[plane]) + eps : 1.f;
+   T _bias = affine ? bias[plane] : 0.f;
+   T _var = var[plane];
+   T _edz = edz[plane];
+   T _eydz = eydz[plane];
+
+   T _mul = _weight * rsqrt(_var + eps);
+   T count = T(num * sp);
+
+   for (int batch = 0; batch < num; ++batch) {
+     for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+       T _dz = dz[(batch * chn + plane) * sp + n];
+       T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
+
+       dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
+     }
+   }
+ }
+
+ at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+                          at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
+   CHECK_CUDA_INPUT(z);
+   CHECK_CUDA_INPUT(dz);
+   CHECK_CUDA_INPUT(var);
+   CHECK_CUDA_INPUT(weight);
+   CHECK_CUDA_INPUT(bias);
+   CHECK_CUDA_INPUT(edz);
+   CHECK_CUDA_INPUT(eydz);
+
+   // Extract dimensions
+   int64_t num, chn, sp;
+   get_dims(z, num, chn, sp);
+
+   auto dx = at::zeros_like(z);
+
+   // Run kernel
+   dim3 blocks(chn);
+   dim3 threads(getNumThreads(sp));
+   auto stream = at::cuda::getCurrentCUDAStream();
+   AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
+     backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+         z.data<scalar_t>(),
+         dz.data<scalar_t>(),
+         var.data<scalar_t>(),
+         weight.data<scalar_t>(),
+         bias.data<scalar_t>(),
+         edz.data<scalar_t>(),
+         eydz.data<scalar_t>(),
+         dx.data<scalar_t>(),
+         affine, eps, num, chn, sp);
+   }));
+
+   return dx;
+ }
+
+ /**************
+  * activations
+  **************/
+
+ template<typename T>
+ inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
+   // Create thrust pointers
+   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
+   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
+
+   auto stream = at::cuda::getCurrentCUDAStream();
+   thrust::transform_if(thrust::cuda::par.on(stream),
+                        th_dz, th_dz + count, th_z, th_dz,
+                        [slope] __device__ (const T& dz) { return dz * slope; },
+                        [] __device__ (const T& z) { return z < 0; });
+   thrust::transform_if(thrust::cuda::par.on(stream),
+                        th_z, th_z + count, th_z,
+                        [slope] __device__ (const T& z) { return z / slope; },
+                        [] __device__ (const T& z) { return z < 0; });
+ }
+
+ void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
+   CHECK_CUDA_INPUT(z);
+   CHECK_CUDA_INPUT(dz);
+
+   int64_t count = z.numel();
+
+   AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
+     leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
+   }));
+ }
+
+ template<typename T>
+ inline void elu_backward_impl(T *z, T *dz, int64_t count) {
+   // Create thrust pointers
+   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
+   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
+
+   auto stream = at::cuda::getCurrentCUDAStream();
+   thrust::transform_if(thrust::cuda::par.on(stream),
+                        th_dz, th_dz + count, th_z, th_z, th_dz,
+                        [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
+                        [] __device__ (const T& z) { return z < 0; });
+   thrust::transform_if(thrust::cuda::par.on(stream),
+                        th_z, th_z + count, th_z,
+                        [] __device__ (const T& z) { return log1p(z); },
+                        [] __device__ (const T& z) { return z < 0; });
+ }
+
+ void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
+   CHECK_CUDA_INPUT(z);
+   CHECK_CUDA_INPUT(dz);
+
+   int64_t count = z.numel();
+
+   AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
+     elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
+   }));
+ }
models/BiSeNet/modules/src/inplace_abn_cuda_half.cu ADDED
@@ -0,0 +1,275 @@
+ #include <ATen/ATen.h>
+
+ #include <cuda_fp16.h>
+
+ #include <vector>
+
+ #include "utils/checks.h"
+ #include "utils/cuda.cuh"
+ #include "inplace_abn.h"
+
+ #include <ATen/cuda/CUDAContext.h>
+
+ // Operations for reduce
+ struct SumOpH {
+   __device__ SumOpH(const half *t, int c, int s)
+       : tensor(t), chn(c), sp(s) {}
+   __device__ __forceinline__ float operator()(int batch, int plane, int n) {
+     return __half2float(tensor[(batch * chn + plane) * sp + n]);
+   }
+   const half *tensor;
+   const int chn;
+   const int sp;
+ };
+
+ struct VarOpH {
+   __device__ VarOpH(float m, const half *t, int c, int s)
+       : mean(m), tensor(t), chn(c), sp(s) {}
+   __device__ __forceinline__ float operator()(int batch, int plane, int n) {
+     const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
+     return (t - mean) * (t - mean);
+   }
+   const float mean;
+   const half *tensor;
+   const int chn;
+   const int sp;
+ };
+
+ struct GradOpH {
+   __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
+       : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
+   __device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
+     float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
+     float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
+     return Pair<float>(_dz, _y * _dz);
+   }
+   const float weight;
+   const float bias;
+   const half *z;
+   const half *dz;
+   const int chn;
+   const int sp;
+ };
+
+ /***********
+  * mean_var
+  ***********/
+
+ __global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
+   int plane = blockIdx.x;
+   float norm = 1.f / static_cast<float>(num * sp);
+
+   float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
+   __syncthreads();
+   float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;
+
+   if (threadIdx.x == 0) {
+     mean[plane] = _mean;
+     var[plane] = _var;
+   }
+ }
+
+ std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
+   CHECK_CUDA_INPUT(x);
+
+   // Extract dimensions
+   int64_t num, chn, sp;
+   get_dims(x, num, chn, sp);
+
+   // Prepare output tensors
+   auto mean = at::empty({chn}, x.options().dtype(at::kFloat));
+   auto var = at::empty({chn}, x.options().dtype(at::kFloat));
+
+   // Run kernel
+   dim3 blocks(chn);
+   dim3 threads(getNumThreads(sp));
+   auto stream = at::cuda::getCurrentCUDAStream();
+   mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
+       reinterpret_cast<half*>(x.data<at::Half>()),
+       mean.data<float>(),
+       var.data<float>(),
+       num, chn, sp);
+
+   return {mean, var};
+ }
+
+ /**********
+  * forward
+  **********/
+
+ __global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
+                                  bool affine, float eps, int num, int chn, int sp) {
+   int plane = blockIdx.x;
+
+   const float _mean = mean[plane];
+   const float _var = var[plane];
+   const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
+   const float _bias = affine ? bias[plane] : 0.f;
+
+   const float mul = rsqrt(_var + eps) * _weight;
+
+   for (int batch = 0; batch < num; ++batch) {
+     for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+       half *x_ptr = x + (batch * chn + plane) * sp + n;
+       float _x = __half2float(*x_ptr);
+       float _y = (_x - _mean) * mul + _bias;
+
+       *x_ptr = __float2half(_y);
+     }
+   }
+ }
+
+ at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+                           bool affine, float eps) {
+   CHECK_CUDA_INPUT(x);
+   CHECK_CUDA_INPUT(mean);
+   CHECK_CUDA_INPUT(var);
+   CHECK_CUDA_INPUT(weight);
+   CHECK_CUDA_INPUT(bias);
+
+   // Extract dimensions
+   int64_t num, chn, sp;
+   get_dims(x, num, chn, sp);
+
+   // Run kernel
+   dim3 blocks(chn);
+   dim3 threads(getNumThreads(sp));
+   auto stream = at::cuda::getCurrentCUDAStream();
+   forward_kernel_h<<<blocks, threads, 0, stream>>>(
+       reinterpret_cast<half*>(x.data<at::Half>()),
+       mean.data<float>(),
+       var.data<float>(),
+       weight.data<float>(),
+       bias.data<float>(),
+       affine, eps, num, chn, sp);
+
+   return x;
+ }
+
+ __global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
+                                   float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
+   int plane = blockIdx.x;
+
+   float _weight = affine ? abs(weight[plane]) + eps : 1.f;
+   float _bias = affine ? bias[plane] : 0.f;
+
+   Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
+   __syncthreads();
+
+   if (threadIdx.x == 0) {
+     edz[plane] = res.v1;
+     eydz[plane] = res.v2;
+   }
+ }
+
+ std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+                                         bool affine, float eps) {
+   CHECK_CUDA_INPUT(z);
+   CHECK_CUDA_INPUT(dz);
+   CHECK_CUDA_INPUT(weight);
+   CHECK_CUDA_INPUT(bias);
+
+   // Extract dimensions
+   int64_t num, chn, sp;
+   get_dims(z, num, chn, sp);
+
+   auto edz = at::empty({chn}, z.options().dtype(at::kFloat));
+   auto eydz = at::empty({chn}, z.options().dtype(at::kFloat));
+
+   // Run kernel
+   dim3 blocks(chn);
+   dim3 threads(getNumThreads(sp));
+   auto stream = at::cuda::getCurrentCUDAStream();
+   edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
+       reinterpret_cast<half*>(z.data<at::Half>()),
+       reinterpret_cast<half*>(dz.data<at::Half>()),
+       weight.data<float>(),
+       bias.data<float>(),
+       edz.data<float>(),
+       eydz.data<float>(),
+       affine, eps, num, chn, sp);
+
+   return {edz, eydz};
+ }
+
+ __global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
+                                   const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
+   int plane = blockIdx.x;
+
+   float _weight = affine ? abs(weight[plane]) + eps : 1.f;
+   float _bias = affine ? bias[plane] : 0.f;
+   float _var = var[plane];
+   float _edz = edz[plane];
+   float _eydz = eydz[plane];
+
+   float _mul = _weight * rsqrt(_var + eps);
+   float count = float(num * sp);
+
+   for (int batch = 0; batch < num; ++batch) {
+     for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+       float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
+       float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;
+
+       dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
+     }
+   }
+ }
+
+ at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+                            at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
+   CHECK_CUDA_INPUT(z);
+   CHECK_CUDA_INPUT(dz);
+   CHECK_CUDA_INPUT(var);
+   CHECK_CUDA_INPUT(weight);
+   CHECK_CUDA_INPUT(bias);
+   CHECK_CUDA_INPUT(edz);
+   CHECK_CUDA_INPUT(eydz);
+
+   // Extract dimensions
+   int64_t num, chn, sp;
+   get_dims(z, num, chn, sp);
+
+   auto dx = at::zeros_like(z);
+
+   // Run kernel
+   dim3 blocks(chn);
+   dim3 threads(getNumThreads(sp));
+   auto stream = at::cuda::getCurrentCUDAStream();
+   backward_kernel_h<<<blocks, threads, 0, stream>>>(
+       reinterpret_cast<half*>(z.data<at::Half>()),
+       reinterpret_cast<half*>(dz.data<at::Half>()),
+       var.data<float>(),
+       weight.data<float>(),
+       bias.data<float>(),
+       edz.data<float>(),
+       eydz.data<float>(),
+       reinterpret_cast<half*>(dx.data<at::Half>()),
+       affine, eps, num, chn, sp);
+
+   return dx;
+ }
+
+ __global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
+   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){
+     float _z = __half2float(z[i]);
+     if (_z < 0) {
+       dz[i] = __float2half(__half2float(dz[i]) * slope);
+       z[i] = __float2half(_z / slope);
+     }
+   }
+ }
+
+ void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
+   CHECK_CUDA_INPUT(z);
+   CHECK_CUDA_INPUT(dz);
+
+   int64_t count = z.numel();
+   dim3 threads(getNumThreads(count));
+   dim3 blocks = (count + threads.x - 1) / threads.x;
+   auto stream = at::cuda::getCurrentCUDAStream();
+   leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
+       reinterpret_cast<half*>(z.data<at::Half>()),
+       reinterpret_cast<half*>(dz.data<at::Half>()),
+       slope, count);
+ }
+
models/BiSeNet/modules/src/utils/checks.h ADDED
@@ -0,0 +1,15 @@
+ #pragma once
+
+ #include <ATen/ATen.h>
+
+ // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
+ #ifndef AT_CHECK
+ #define AT_CHECK AT_ASSERT
+ #endif
+
+ #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
+ #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
+
+ #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+ #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
models/BiSeNet/modules/src/utils/common.h ADDED
@@ -0,0 +1,49 @@
+ #pragma once
+
+ #include <ATen/ATen.h>
+
+ /*
+  * Functions to share code between CPU and GPU
+  */
+
+ #ifdef __CUDACC__
+ // CUDA versions
+
+ #define HOST_DEVICE __host__ __device__
+ #define INLINE_HOST_DEVICE __host__ __device__ inline
+ #define FLOOR(x) floor(x)
+
+ #if __CUDA_ARCH__ >= 600
+ // Recent compute capabilities have block-level atomicAdd for all data types, so we use that
+ #define ACCUM(x,y) atomicAdd_block(&(x),(y))
+ #else
+ // Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
+ // and use the known atomicCAS-based implementation for double
+ template<typename data_t>
+ __device__ inline data_t atomic_add(data_t *address, data_t val) {
+   return atomicAdd(address, val);
+ }
+
+ template<>
+ __device__ inline double atomic_add(double *address, double val) {
+   unsigned long long int* address_as_ull = (unsigned long long int*)address;
+   unsigned long long int old = *address_as_ull, assumed;
+   do {
+     assumed = old;
+     old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
+   } while (assumed != old);
+   return __longlong_as_double(old);
+ }
+
+ #define ACCUM(x,y) atomic_add(&(x),(y))
+ #endif // #if __CUDA_ARCH__ >= 600
+
+ #else
+ // CPU versions
+
+ #define HOST_DEVICE
+ #define INLINE_HOST_DEVICE inline
+ #define FLOOR(x) std::floor(x)
+ #define ACCUM(x,y) (x) += (y)
+
+ #endif // #ifdef __CUDACC__
models/BiSeNet/modules/src/utils/cuda.cuh ADDED
@@ -0,0 +1,71 @@
+ #pragma once
+
+ /*
+  * General settings and functions
+  */
+ const int WARP_SIZE = 32;
+ const int MAX_BLOCK_SIZE = 1024;
+
+ static int getNumThreads(int nElem) {
+   int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
+   for (int i = 0; i < 6; ++i) {
+     if (nElem <= threadSizes[i]) {
+       return threadSizes[i];
+     }
+   }
+   return MAX_BLOCK_SIZE;
+ }
+
+ /*
+  * Reduction utilities
+  */
+ template <typename T>
+ __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
+                                            unsigned int mask = 0xffffffff) {
+ #if CUDART_VERSION >= 9000
+   return __shfl_xor_sync(mask, value, laneMask, width);
+ #else
+   return __shfl_xor(value, laneMask, width);
+ #endif
+ }
+
+ __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
+
+ template<typename T>
+ struct Pair {
+   T v1, v2;
+   __device__ Pair() {}
+   __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
+   __device__ Pair(T v) : v1(v), v2(v) {}
+   __device__ Pair(int v) : v1(v), v2(v) {}
+   __device__ Pair &operator+=(const Pair<T> &a) {
+     v1 += a.v1;
+     v2 += a.v2;
+     return *this;
+   }
+ };
+
+ template<typename T>
+ static __device__ __forceinline__ T warpSum(T val) {
+ #if __CUDA_ARCH__ >= 300
+   for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
+     val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
+   }
+ #else
+   __shared__ T values[MAX_BLOCK_SIZE];
+   values[threadIdx.x] = val;
+   __threadfence_block();
+   const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
+   for (int i = 1; i < WARP_SIZE; i++) {
+     val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
+   }
+ #endif
+   return val;
+ }
+
+ template<typename T>
+ static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
+   value.v1 = warpSum(value.v1);
+   value.v2 = warpSum(value.v2);
+   return value;
+ }
models/BiSeNet/optimizer.py ADDED
@@ -0,0 +1,69 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+
+ import torch
+ import logging
+
+ logger = logging.getLogger()
+
+ class Optimizer(object):
+     def __init__(self,
+                  model,
+                  lr0,
+                  momentum,
+                  wd,
+                  warmup_steps,
+                  warmup_start_lr,
+                  max_iter,
+                  power,
+                  *args, **kwargs):
+         self.warmup_steps = warmup_steps
+         self.warmup_start_lr = warmup_start_lr
+         self.lr0 = lr0
+         self.lr = self.lr0
+         self.max_iter = float(max_iter)
+         self.power = power
+         self.it = 0
+         wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
+         param_list = [
+                 {'params': wd_params},
+                 {'params': nowd_params, 'weight_decay': 0},
+                 {'params': lr_mul_wd_params, 'lr_mul': True},
+                 {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}]
+         self.optim = torch.optim.SGD(
+                 param_list,
+                 lr = lr0,
+                 momentum = momentum,
+                 weight_decay = wd)
+         self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps)
+
+
+     def get_lr(self):
+         if self.it <= self.warmup_steps:
+             lr = self.warmup_start_lr*(self.warmup_factor**self.it)
+         else:
+             factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power
+             lr = self.lr0 * factor
+         return lr
+
+
+     def step(self):
+         self.lr = self.get_lr()
+         for pg in self.optim.param_groups:
+             if pg.get('lr_mul', False):
+                 pg['lr'] = self.lr * 10
+             else:
+                 pg['lr'] = self.lr
+         if self.optim.defaults.get('lr_mul', False):
+             self.optim.defaults['lr'] = self.lr * 10
+         else:
+             self.optim.defaults['lr'] = self.lr
+         self.it += 1
+         self.optim.step()
+         if self.it == self.warmup_steps+2:
+             logger.info('==> warmup done, start to implement poly lr strategy')
+
+     def zero_grad(self):
+         self.optim.zero_grad()
+
models/BiSeNet/prepropess_data.py ADDED
@@ -0,0 +1,38 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+ import os.path as osp
+ import os
+ import cv2
+ from transform import *
+ from PIL import Image
+
+ face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
+ face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
+ mask_path = '/home/zll/data/CelebAMask-HQ/mask'
+ counter = 0
+ total = 0
+ for i in range(15):
+
+     atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
+             'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
+
+     for j in range(i * 2000, (i + 1) * 2000):
+
+         mask = np.zeros((512, 512))
+
+         for l, att in enumerate(atts, 1):
+             total += 1
+             file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
+             path = osp.join(face_sep_mask, str(i), file_name)
+
+             if os.path.exists(path):
+                 counter += 1
+                 sep_mask = np.array(Image.open(path).convert('P'))
+                 # print(np.unique(sep_mask))
+
+                 mask[sep_mask == 225] = l
+         cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
+         print(j)
+
+ print(counter, total)
models/BiSeNet/resnet.py ADDED
@@ -0,0 +1,109 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.model_zoo as modelzoo
+
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     def __init__(self, in_chan, out_chan, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(in_chan, out_chan, stride)
+         self.bn1 = nn.BatchNorm2d(out_chan)
+         self.conv2 = conv3x3(out_chan, out_chan)
+         self.bn2 = nn.BatchNorm2d(out_chan)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = None
+         if in_chan != out_chan or stride != 1:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_chan, out_chan,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_chan),
+                 )
+
+     def forward(self, x):
+         residual = self.conv1(x)
+         residual = F.relu(self.bn1(residual))
+         residual = self.conv2(residual)
+         residual = self.bn2(residual)
+
+         shortcut = x
+         if self.downsample is not None:
+             shortcut = self.downsample(x)
+
+         out = shortcut + residual
+         out = self.relu(out)
+         return out
+
+
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+     layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+     for i in range(bnum-1):
+         layers.append(BasicBlock(out_chan, out_chan, stride=1))
+     return nn.Sequential(*layers)
+
+
+ class Resnet18(nn.Module):
+     def __init__(self):
+         super(Resnet18, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+         self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+         self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+         self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = F.relu(self.bn1(x))
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         feat8 = self.layer2(x) # 1/8
+         feat16 = self.layer3(feat8) # 1/16
+         feat32 = self.layer4(feat16) # 1/32
+         return feat8, feat16, feat32
+
+     def init_weight(self):
+         state_dict = modelzoo.load_url(resnet18_url)
+         self_state_dict = self.state_dict()
+         for k, v in state_dict.items():
+             if 'fc' in k: continue
+             self_state_dict.update({k: v})
+         self.load_state_dict(self_state_dict)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ if __name__ == "__main__":
+     net = Resnet18()
+     x = torch.randn(16, 3, 224, 224)
+     out = net(x)
+     print(out[0].size())
+     print(out[1].size())
+     print(out[2].size())
+     net.get_params()
models/BiSeNet/test.py ADDED
@@ -0,0 +1,90 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+ from logger import setup_logger
+ from model import BiSeNet
+
+ import torch
+
+ import os
+ import os.path as osp
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as transforms
+ import cv2
+
+ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
+     # Colors for all 20 parts
+     part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
+                    [255, 0, 85], [255, 0, 170],
+                    [0, 255, 0], [85, 255, 0], [170, 255, 0],
+                    [0, 255, 85], [0, 255, 170],
+                    [0, 0, 255], [85, 0, 255], [170, 0, 255],
+                    [0, 85, 255], [0, 170, 255],
+                    [255, 255, 0], [255, 255, 85], [255, 255, 170],
+                    [255, 0, 255], [255, 85, 255], [255, 170, 255],
+                    [0, 255, 255], [85, 255, 255], [170, 255, 255]]
+
+     im = np.array(im)
+     vis_im = im.copy().astype(np.uint8)
+     vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
+     vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
+     vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
+
+     num_of_class = np.max(vis_parsing_anno)
+
+     for pi in range(1, num_of_class + 1):
+         index = np.where(vis_parsing_anno == pi)
+         vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
+
+     vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
+     # print(vis_parsing_anno_color.shape, vis_im.shape)
+     vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
+
+     # Save result or not
+     if save_im:
+         cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
+         cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
+
+     # return vis_im
+
+ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
+
+     if not os.path.exists(respth):
+         os.makedirs(respth)
+
+     n_classes = 19
+     net = BiSeNet(n_classes=n_classes)
+     net.cuda()
+     save_pth = osp.join('res/cp', cp)
+     net.load_state_dict(torch.load(save_pth))
+     net.eval()
+
+     to_tensor = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+     ])
+     with torch.no_grad():
+         for image_path in os.listdir(dspth):
+             img = Image.open(osp.join(dspth, image_path))
+             image = img.resize((512, 512), Image.BILINEAR)
+             img = to_tensor(image)
+             img = torch.unsqueeze(img, 0)
+             img = img.cuda()
+             out = net(img)[0]
+             parsing = out.squeeze(0).cpu().numpy().argmax(0)
+             # print(parsing)
+             print(np.unique(parsing))
+
+             vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
+
+
+
+
+
+
+
+
+ if __name__ == "__main__":
+     evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth')
+
+
models/BiSeNet/train.py ADDED
@@ -0,0 +1,179 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+ from logger import setup_logger
+ from model import BiSeNet
+ from face_dataset import FaceMask
+ from loss import OhemCELoss
+ from evaluate import evaluate
+ from optimizer import Optimizer
+ import cv2
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ import torch.nn.functional as F
+ import torch.distributed as dist
+
+ import os
+ import os.path as osp
+ import logging
+ import time
+ import datetime
+ import argparse
+
+
+ respth = './res'
+ if not osp.exists(respth):
+     os.makedirs(respth)
+ logger = logging.getLogger()
+
+
+ def parse_args():
+     parse = argparse.ArgumentParser()
+     parse.add_argument(
+             '--local_rank',
+             dest = 'local_rank',
+             type = int,
+             default = -1,
+             )
+     return parse.parse_args()
+
+
+ def train():
+     args = parse_args()
+     torch.cuda.set_device(args.local_rank)
+     dist.init_process_group(
+                 backend = 'nccl',
+                 init_method = 'tcp://127.0.0.1:33241',
+                 world_size = torch.cuda.device_count(),
+                 rank=args.local_rank
+                 )
+     setup_logger(respth)
+
+     # dataset
+     n_classes = 19
+     n_img_per_gpu = 16
+     n_workers = 8
+     cropsize = [448, 448]
+     data_root = '/home/zll/data/CelebAMask-HQ/'
+
+     ds = FaceMask(data_root, cropsize=cropsize, mode='train')
+     sampler = torch.utils.data.distributed.DistributedSampler(ds)
+     dl = DataLoader(ds,
+                     batch_size = n_img_per_gpu,
+                     shuffle = False,
+                     sampler = sampler,
+                     num_workers = n_workers,
+                     pin_memory = True,
+                     drop_last = True)
+
+     # model
+     ignore_idx = -100
+     net = BiSeNet(n_classes=n_classes)
+     net.cuda()
+     net.train()
+     net = nn.parallel.DistributedDataParallel(net,
+             device_ids = [args.local_rank, ],
+             output_device = args.local_rank
+             )
+     score_thres = 0.7
+     n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16
+     LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
+     Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
+     Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
+
+     ## optimizer
+     momentum = 0.9
+     weight_decay = 5e-4
+     lr_start = 1e-2
+     max_iter = 80000
+     power = 0.9
+     warmup_steps = 1000
+     warmup_start_lr = 1e-5
+     optim = Optimizer(
+             model = net.module,
+             lr0 = lr_start,
+             momentum = momentum,
+             wd = weight_decay,
+             warmup_steps = warmup_steps,
+             warmup_start_lr = warmup_start_lr,
+             max_iter = max_iter,
+             power = power)
+
+     ## train loop
+     msg_iter = 50
+     loss_avg = []
+     st = glob_st = time.time()
+     diter = iter(dl)
+     epoch = 0
+     for it in range(max_iter):
+         try:
+             im, lb = next(diter)
+             if not im.size()[0] == n_img_per_gpu:
+                 raise StopIteration
+         except StopIteration:
+             epoch += 1
+             sampler.set_epoch(epoch)
+             diter = iter(dl)
+             im, lb = next(diter)
+         im = im.cuda()
+         lb = lb.cuda()
+         H, W = im.size()[2:]
+         lb = torch.squeeze(lb, 1)
+
+         optim.zero_grad()
+         out, out16, out32 = net(im)
+         lossp = LossP(out, lb)
+         loss2 = Loss2(out16, lb)
+         loss3 = Loss3(out32, lb)
+         loss = lossp + loss2 + loss3
+         loss.backward()
+         optim.step()
+
+         loss_avg.append(loss.item())
+
+         # print training log message
+         if (it+1) % msg_iter == 0:
+             loss_avg = sum(loss_avg) / len(loss_avg)
+             lr = optim.lr
+             ed = time.time()
+             t_intv, glob_t_intv = ed - st, ed - glob_st
+             eta = int((max_iter - it) * (glob_t_intv / it))
+             eta = str(datetime.timedelta(seconds=eta))
+             msg = ', '.join([
+                     'it: {it}/{max_it}',
+                     'lr: {lr:4f}',
+                     'loss: {loss:.4f}',
+                     'eta: {eta}',
+                     'time: {time:.4f}',
+                 ]).format(
+                     it = it+1,
+                     max_it = max_iter,
+                     lr = lr,
+                     loss = loss_avg,
+                     time = t_intv,
+                     eta = eta
+                 )
+             logger.info(msg)
+             loss_avg = []
+             st = ed
+         if dist.get_rank() == 0:
+             if (it+1) % 5000 == 0:
+                 state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
+                 if dist.get_rank() == 0:
+                     torch.save(state, './res/cp/{}_iter.pth'.format(it))
+                 evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))
+
+     # dump the final model
+     save_pth = osp.join(respth, 'model_final_diss.pth')
+     # net.cpu()
+     state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
+     if dist.get_rank() == 0:
+         torch.save(state, save_pth)
+     logger.info('training done, model saved to: {}'.format(save_pth))
+
+
+ if __name__ == "__main__":
+     train()
models/BiSeNet/transform.py ADDED
@@ -0,0 +1,129 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+
+ from PIL import Image
+ import PIL.ImageEnhance as ImageEnhance
+ import random
+ import numpy as np
+
+ class RandomCrop(object):
+     def __init__(self, size, *args, **kwargs):
+         self.size = size
+
+     def __call__(self, im_lb):
+         im = im_lb['im']
+         lb = im_lb['lb']
+         assert im.size == lb.size
+         W, H = self.size
+         w, h = im.size
+
+         if (W, H) == (w, h): return dict(im=im, lb=lb)
+         if w < W or h < H:
+             scale = float(W) / w if w < h else float(H) / h
+             w, h = int(scale * w + 1), int(scale * h + 1)
+             im = im.resize((w, h), Image.BILINEAR)
+             lb = lb.resize((w, h), Image.NEAREST)
+         sw, sh = random.random() * (w - W), random.random() * (h - H)
+         crop = int(sw), int(sh), int(sw) + W, int(sh) + H
+         return dict(
+                 im = im.crop(crop),
+                 lb = lb.crop(crop)
+                 )
+
+
+ class HorizontalFlip(object):
+     def __init__(self, p=0.5, *args, **kwargs):
+         self.p = p
+
+     def __call__(self, im_lb):
+         if random.random() > self.p:
+             return im_lb
+         else:
+             im = im_lb['im']
+             lb = im_lb['lb']
+
+             # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r',
+             # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
+
+             flip_lb = np.array(lb)
+             flip_lb[lb == 2] = 3
+             flip_lb[lb == 3] = 2
+             flip_lb[lb == 4] = 5
+             flip_lb[lb == 5] = 4
+             flip_lb[lb == 7] = 8
+             flip_lb[lb == 8] = 7
+             flip_lb = Image.fromarray(flip_lb)
+             return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT),
+                         lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT),
+                         )
+
+
+ class RandomScale(object):
+     def __init__(self, scales=(1, ), *args, **kwargs):
+         self.scales = scales
+
+     def __call__(self, im_lb):
+         im = im_lb['im']
+         lb = im_lb['lb']
+         W, H = im.size
+         scale = random.choice(self.scales)
+         w, h = int(W * scale), int(H * scale)
+         return dict(im = im.resize((w, h), Image.BILINEAR),
+                     lb = lb.resize((w, h), Image.NEAREST),
+                     )
+
+
+ class ColorJitter(object):
+     def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs):
+         if not brightness is None and brightness>0:
+             self.brightness = [max(1-brightness, 0), 1+brightness]
+         if not contrast is None and contrast>0:
+             self.contrast = [max(1-contrast, 0), 1+contrast]
+         if not saturation is None and saturation>0:
+             self.saturation = [max(1-saturation, 0), 1+saturation]
+
+     def __call__(self, im_lb):
+         im = im_lb['im']
+         lb = im_lb['lb']
+         r_brightness = random.uniform(self.brightness[0], self.brightness[1])
+         r_contrast = random.uniform(self.contrast[0], self.contrast[1])
+         r_saturation = random.uniform(self.saturation[0], self.saturation[1])
+         im = ImageEnhance.Brightness(im).enhance(r_brightness)
+         im = ImageEnhance.Contrast(im).enhance(r_contrast)
+         im = ImageEnhance.Color(im).enhance(r_saturation)
+         return dict(im = im,
+                     lb = lb,
+                     )
+
+
+ class MultiScale(object):
+     def __init__(self, scales):
+         self.scales = scales
+
+     def __call__(self, img):
+         W, H = img.size
+         sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales]
+         imgs = []
+         [imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes]
+         return imgs
+
+
+ class Compose(object):
+     def __init__(self, do_list):
+         self.do_list = do_list
+
+     def __call__(self, im_lb):
+         for comp in self.do_list:
+             im_lb = comp(im_lb)
+         return im_lb
+
+
+
+
+ if __name__ == '__main__':
+     flip = HorizontalFlip(p = 1)
+     crop = RandomCrop((321, 321))
+     rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0))
+     img = Image.open('data/img.jpg')
+     lb = Image.open('data/label.png')
models/BiSeNet_pretrained_for_ConsistentID.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
+ size 53289463
models/LLaVA/.devcontainer/Dockerfile ADDED
@@ -0,0 +1,53 @@
+ FROM mcr.microsoft.com/devcontainers/base:ubuntu-20.04
+
+ SHELL [ "bash", "-c" ]
+
+ # update apt and install packages
+ RUN apt update && \
+     apt install -yq \
+         ffmpeg \
+         dkms \
+         build-essential
+
+ # add user tools
+ RUN sudo apt install -yq \
+     jq \
+     jp \
+     tree \
+     tldr
+
+ # add git-lfs and install
+ RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash && \
+     sudo apt-get install -yq git-lfs && \
+     git lfs install
+
+ ############################################
+ # Setup user
+ ############################################
+
+ USER vscode
+
+ # install azcopy, a tool to copy to/from blob storage
+ # for more info: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-blobs-upload#upload-a-file
+ RUN cd /tmp && \
+     wget https://azcopyvnext.azureedge.net/release20230123/azcopy_linux_amd64_10.17.0.tar.gz && \
+     tar xvf azcopy_linux_amd64_10.17.0.tar.gz && \
+     mkdir -p ~/.local/bin && \
+     mv azcopy_linux_amd64_10.17.0/azcopy ~/.local/bin && \
+     chmod +x ~/.local/bin/azcopy && \
+     rm -rf azcopy_linux_amd64*
+
+ # Setup conda
+ RUN cd /tmp && \
+     wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
+     bash ./Miniconda3-latest-Linux-x86_64.sh -b && \
+     rm ./Miniconda3-latest-Linux-x86_64.sh
+
+ # Install dotnet
+ RUN cd /tmp && \
+     wget https://dot.net/v1/dotnet-install.sh && \
+     chmod +x dotnet-install.sh && \
+     ./dotnet-install.sh --channel 7.0 && \
+     ./dotnet-install.sh --channel 3.1 && \
+     rm ./dotnet-install.sh
+
models/LLaVA/.devcontainer/devcontainer.env ADDED
@@ -0,0 +1,2 @@
+ SAMPLE_ENV_VAR1="Sample Value"
+ SAMPLE_ENV_VAR2=332431bf-68bf
models/LLaVA/.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,71 @@
+ {
+     "name": "LLaVA",
+     "build": {
+         "dockerfile": "Dockerfile",
+         "context": "..",
+         "args": {}
+     },
+     "features": {
+         "ghcr.io/devcontainers/features/docker-in-docker:2": {},
+         "ghcr.io/devcontainers/features/azure-cli:1": {},
+         "ghcr.io/azure/azure-dev/azd:0": {},
+         "ghcr.io/devcontainers/features/powershell:1": {},
+         "ghcr.io/devcontainers/features/common-utils:2": {},
+         "ghcr.io/devcontainers-contrib/features/zsh-plugins:0": {},
+     },
+     // "forwardPorts": [],
+     "postCreateCommand": "bash ./.devcontainer/postCreateCommand.sh",
+     "customizations": {
+         "vscode": {
+             "settings": {
+                 "python.analysis.autoImportCompletions": true,
+                 "python.analysis.autoImportUserSymbols": true,
+                 "python.defaultInterpreterPath": "~/miniconda3/envs/llava/bin/python",
+                 "python.formatting.provider": "yapf",
+                 "python.linting.enabled": true,
+                 "python.linting.flake8Enabled": true,
+                 "isort.check": true,
+                 "dev.containers.copyGitConfig": true,
+                 "terminal.integrated.defaultProfile.linux": "zsh",
+                 "terminal.integrated.profiles.linux": {
+                     "zsh": {
+                         "path": "/usr/bin/zsh"
+                     },
+                 }
+             },
+             "extensions": [
+                 "aaron-bond.better-comments",
+                 "eamodio.gitlens",
+                 "EditorConfig.EditorConfig",
+                 "foxundermoon.shell-format",
+                 "GitHub.copilot-chat",
+                 "GitHub.copilot-labs",
+                 "GitHub.copilot",
+                 "lehoanganh298.json-lines-viewer",
+                 "mhutchie.git-graph",
+                 "ms-azuretools.vscode-docker",
+                 "ms-dotnettools.dotnet-interactive-vscode",
+                 "ms-python.flake8",
+                 "ms-python.isort",
+                 "ms-python.python",
+                 "ms-python.vscode-pylance",
+                 "njpwerner.autodocstring",
+                 "redhat.vscode-yaml",
+                 "stkb.rewrap",
+                 "yzhang.markdown-all-in-one",
+             ]
+         }
+     },
+     "mounts": [],
+     "runArgs": [
+         "--gpus",
+         "all",
+         // "--ipc",
+         // "host",
+         "--ulimit",
+         "memlock=-1",
+         "--env-file",
+         ".devcontainer/devcontainer.env"
+     ],
+     // "remoteUser": "root"
+ }