remove MeshLab dependency with Open3D
This view is limited to 50 files because it contains too many changes.
- README.md +8 -6
- apps/IFGeo.py +37 -45
- apps/Normal.py +37 -40
- apps/avatarizer.py +28 -13
- apps/infer.py +211 -87
- apps/multi_render.py +1 -3
- configs/econ.yaml +1 -0
- docs/installation.md +2 -3
- lib/common/BNI.py +22 -22
- lib/common/BNI_utils.py +104 -93
- lib/common/blender_utils.py +29 -22
- lib/common/cloth_extraction.py +13 -14
- lib/common/config.py +7 -9
- lib/common/imutils.py +147 -338
- lib/common/libmesh/inside_mesh.py +9 -11
- lib/common/libmesh/setup.py +1 -4
- lib/common/libvoxelize/setup.py +1 -2
- lib/common/local_affine.py +31 -21
- lib/common/render.py +79 -47
- lib/common/render_utils.py +16 -25
- lib/common/seg3d_lossless.py +130 -154
- lib/common/seg3d_utils.py +67 -117
- lib/common/train_util.py +25 -445
- lib/common/voxelize.py +91 -80
- lib/dataset/Evaluator.py +29 -19
- lib/dataset/NormalDataset.py +45 -56
- lib/dataset/NormalModule.py +2 -3
- lib/dataset/PointFeat.py +10 -4
- lib/dataset/TestDataset.py +44 -16
- lib/dataset/body_model.py +68 -91
- lib/dataset/mesh_util.py +112 -416
- lib/net/BasePIFuNet.py +3 -4
- lib/net/Discriminator.py +76 -65
- lib/net/FBNet.py +99 -92
- lib/net/GANLoss.py +2 -3
- lib/net/IFGeoNet.py +72 -64
- lib/net/IFGeoNet_nobody.py +54 -44
- lib/net/NormalNet.py +13 -10
- lib/net/geometry.py +68 -56
- lib/net/net_util.py +22 -27
- lib/net/voxelize.py +21 -15
- lib/pixielib/models/FLAME.py +13 -15
- lib/pixielib/models/SMPLX.py +495 -502
- lib/pixielib/models/encoders.py +2 -5
- lib/pixielib/models/hrnet.py +108 -152
- lib/pixielib/models/lbs.py +25 -42
- lib/pixielib/models/moderators.py +2 -10
- lib/pixielib/models/resnet.py +12 -42
- lib/pixielib/pixie.py +102 -136
- lib/pixielib/utils/array_cropper.py +15 -20
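The commit swaps MeshLab-based mesh post-processing for Open3D; the apps/infer.py hunk further below calls `remesh_laplacian(side_mesh, side_mesh_path)` on the IF-Nets+ completion. As a rough illustration of what such a helper can do with Open3D, here is a hedged sketch — the function body, iteration count, and triangle budget are assumptions for illustration, not the repository's actual implementation:

```python
import numpy as np
import open3d as o3d
import trimesh

def remesh_laplacian_sketch(mesh: trimesh.Trimesh, out_path: str) -> trimesh.Trimesh:
    """Laplacian smoothing + light decimation via Open3D instead of MeshLab.
    Sketch only: the commit's actual remesh_laplacian helper may use different
    iteration counts / target sizes."""
    o3d_mesh = o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(np.asarray(mesh.vertices)),
        o3d.utility.Vector3iVector(np.asarray(mesh.faces)),
    )
    o3d_mesh = o3d_mesh.simplify_quadric_decimation(target_number_of_triangles=50000)
    o3d_mesh = o3d_mesh.filter_smooth_laplacian(number_of_iterations=3)
    o3d_mesh.remove_degenerate_triangles()
    o3d_mesh.remove_unreferenced_vertices()
    o3d_mesh.compute_vertex_normals()
    o3d.io.write_triangle_mesh(out_path, o3d_mesh)
    return trimesh.Trimesh(
        np.asarray(o3d_mesh.vertices), np.asarray(o3d_mesh.triangles), process=False
    )
```

Using Open3D here removes the external MeshLab/pymeshlab dependency while keeping the smoothing step as an in-process Python call.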
README.md
CHANGED
```diff
@@ -103,20 +103,23 @@ python -m apps.avatarizer -n {filename}

 ### Some adjustable parameters in _config/econ.yaml_

-- `use_ifnet:
-  - True: use IF-Nets+ for mesh completion ( $\text{ECON}_\text{IF}$ - Better quality)
-  - False: use SMPL-X for mesh completion ( $\text{ECON}_\text{EX}$ - Faster speed)
+- `use_ifnet: False`
+  - True: use IF-Nets+ for mesh completion ( $\text{ECON}_\text{IF}$ - Better quality, **~2min / img**)
+  - False: use SMPL-X for mesh completion ( $\text{ECON}_\text{EX}$ - Faster speed, **~1.5min / img**)
 - `use_smpl: ["hand", "face"]`
   - [ ]: don't use either hands or face parts from SMPL-X
   - ["hand"]: only use the **visible** hands from SMPL-X
   - ["hand", "face"]: use both **visible** hands and face from SMPL-X
 - `thickness: 2cm`
   - could be increased accordingly in case final reconstruction **xx_full.obj** looks flat
+- `k: 4`
+  - could be reduced accordingly in case the surface of **xx_full.obj** has discontinous artifacts
 - `hps_type: PIXIE`
   - "pixie": more accurate for face and hands
   - "pymafx": more robust for challenging poses
+- `texture_src: image`
+  - "image": direct mapping the aligned pixels to final mesh
+  - "SD": use Stable Diffusion to generate full texture (TODO)

 <br/>

@@ -160,7 +163,6 @@ Here are some great resources we benefit from:

 - [BiNI](https://github.com/hoshino042/bilateral_normal_integration) for Bilateral Normal Integration
 - [MonoPortDataset](https://github.com/Project-Splinter/MonoPortDataset) for Data Processing, [MonoPort](https://github.com/Project-Splinter/MonoPort) for fast implicit surface query
 - [rembg](https://github.com/danielgatis/rembg) for Human Segmentation
-- [pypoisson](https://github.com/mmolero/pypoisson) for poisson reconstruction
 - [MediaPipe](https://google.github.io/mediapipe/getting_started/python.html) for full-body landmark estimation
 - [PyTorch-NICP](https://github.com/wuhaozhe/pytorch-nicp) for non-rigid registration
 - [smplx](https://github.com/vchoutas/smplx), [PyMAF-X](https://www.liuyebin.com/pymaf-x/), [PIXIE](https://github.com/YadiraF/PIXIE) for Human Pose & Shape Estimation
```
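The README hunk above documents the new econ.yaml knobs. A minimal sketch of overriding them programmatically with a yacs-style `merge_from_list`, mirroring how apps/infer.py builds `cfg_show_list` later in this diff; the exact key paths (`bni.use_ifnet`, `bni.k`, `bni.texture_src`, `bni.thickness`) are assumptions based on the config names that appear elsewhere in the commit:

```python
# Sketch only (not from the commit): override the knobs documented above at runtime.
from lib.common.config import cfg

overrides = [
    "bni.use_ifnet", False,      # SMPL-X completion (ECON_EX), faster
    "bni.k", 4,                  # lower if xx_full.obj shows discontinuous artifacts
    "bni.texture_src", "image",  # map aligned pixels onto the final mesh
    "bni.thickness", 0.02,       # raise if xx_full.obj looks flat
]
cfg.defrost()
cfg.merge_from_list(overrides)
cfg.freeze()
```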
apps/IFGeo.py
CHANGED
```diff
@@ -24,7 +24,6 @@ torch.backends.cudnn.benchmark = True


 class IFGeo(pl.LightningModule):
-
     def __init__(self, cfg):
         super(IFGeo, self).__init__()

@@ -44,14 +43,15 @@ class IFGeo(pl.LightningModule):
             from lib.net.IFGeoNet_nobody import IFGeoNet
             self.netG = IFGeoNet(cfg)

+        self.resolutions = (
+            np.logspace(
+                start=5,
+                stop=np.log2(self.mcube_res),
+                base=2,
+                num=int(np.log2(self.mcube_res) - 4),
+                endpoint=True,
+            ) + 1.0
+        )
         self.resolutions = self.resolutions.astype(np.int16).tolist()

@@ -82,9 +82,9 @@ class IFGeo(pl.LightningModule):

         if self.cfg.optim == "Adadelta":

+            optimizer_G = torch.optim.Adadelta(
+                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
+            )

         elif self.cfg.optim == "Adam":

@@ -103,20 +103,14 @@ class IFGeo(pl.LightningModule):
             raise NotImplementedError

         # set scheduler
+        scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+        )

         return [optimizer_G], [scheduler_G]

     def training_step(self, batch, batch_idx):

-        # cfg log
-        if self.cfg.devices == 1:
-            if not self.cfg.fast_dev and self.global_step == 0:
-                export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg)
-                self.logger.experiment.config.update(convert_to_dict(self.cfg))
-
         self.netG.train()

         preds_G = self.netG(batch)

@@ -127,12 +121,9 @@ class IFGeo(pl.LightningModule):
             "loss": error_G,
         }

+        self.log_dict(
+            metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
+        )

         return metrics_log

@@ -143,12 +134,14 @@ class IFGeo(pl.LightningModule):
             "train/avgloss": batch_mean(outputs, "loss"),
         }

+        self.log_dict(
+            metrics_log,
+            prog_bar=False,
+            logger=True,
+            on_step=False,
+            on_epoch=True,
+            rank_zero_only=True
+        )

     def validation_step(self, batch, batch_idx):

@@ -162,12 +155,9 @@ class IFGeo(pl.LightningModule):
             "val/loss": error_G,
         }

+        self.log_dict(
+            metrics_log, prog_bar=True, logger=False, on_step=True, on_epoch=False, sync_dist=True
+        )

         return metrics_log

@@ -178,9 +168,11 @@ class IFGeo(pl.LightningModule):
             "val/avgloss": batch_mean(outputs, "val/loss"),
         }

+        self.log_dict(
+            metrics_log,
+            prog_bar=False,
+            logger=True,
+            on_step=False,
+            on_epoch=True,
+            rank_zero_only=True
+        )
```
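The IFGeo hunk above builds a coarse-to-fine marching-cubes resolution list with `np.logspace`. A quick check of what that expression produces, with `mcube_res = 256` used purely as an illustrative value:

```python
import numpy as np

mcube_res = 256  # illustrative value; in the module it comes from the config
resolutions = (
    np.logspace(
        start=5,
        stop=np.log2(mcube_res),
        base=2,
        num=int(np.log2(mcube_res) - 4),
        endpoint=True,
    ) + 1.0
)
# powers of two from 2^5 to 2^8, plus one vertex per axis
print(resolutions.astype(np.int16).tolist())  # [33, 65, 129, 257]
```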
apps/Normal.py
CHANGED
```diff
@@ -1,14 +1,12 @@
 from lib.net import NormalNet
+from lib.common.train_util import batch_mean
 import torch
 import numpy as np
-import os.path as osp
 from skimage.transform import resize
 import pytorch_lightning as pl


 class Normal(pl.LightningModule):
-
     def __init__(self, cfg):
         super(Normal, self).__init__()
         self.cfg = cfg

@@ -44,19 +42,19 @@ class Normal(pl.LightningModule):
         optimizer_N_F = torch.optim.Adam(optim_params_N_F, lr=self.lr_F, betas=(0.5, 0.999))
         optimizer_N_B = torch.optim.Adam(optim_params_N_B, lr=self.lr_B, betas=(0.5, 0.999))

+        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+        )

+        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+        )
         if 'gan' in self.ALL_losses:
             optim_params_N_D = [{"params": self.netG.netD.parameters(), "lr": self.lr_D}]
             optimizer_N_D = torch.optim.Adam(optim_params_N_D, lr=self.lr_D, betas=(0.5, 0.999))
+            scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR(
+                optimizer_N_D, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+            )
             self.schedulers = [scheduler_N_F, scheduler_N_B, scheduler_N_D]
             optims = [optimizer_N_F, optimizer_N_B, optimizer_N_D]

@@ -77,19 +75,16 @@ class Normal(pl.LightningModule):
                     ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(1, 2, 0),
                     (height, height),
                     anti_aliasing=True,
+                )
+            )

+        self.logger.log_image(
+            key=f"Normal/{dataset}/{idx if not self.overfit else 1}",
+            images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8)]
+        )

     def training_step(self, batch, batch_idx):

-        # cfg log
-        if not self.cfg.fast_dev and self.global_step == 0 and self.cfg.devices == 1:
-            export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg)
-            self.logger.experiment.config.update(convert_to_dict(self.cfg))
-
         self.netG.train()

         # retrieve the data

@@ -125,7 +120,8 @@ class Normal(pl.LightningModule):
             opt_B.step()

         if batch_idx > 0 and batch_idx % int(
+            self.cfg.freq_show_train
+        ) == 0 and self.cfg.devices == 1:

             self.netG.eval()
             with torch.no_grad():

@@ -142,12 +138,9 @@ class Normal(pl.LightningModule):
         for key in error_dict.keys():
             metrics_log["train/loss_" + key] = error_dict[key].item()

+        self.log_dict(
+            metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
+        )

         return metrics_log

@@ -163,12 +156,14 @@ class Normal(pl.LightningModule):
             loss_name = key
             metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)

+        self.log_dict(
+            metrics_log,
+            prog_bar=False,
+            logger=True,
+            on_step=False,
+            on_epoch=True,
+            rank_zero_only=True
+        )

     def validation_step(self, batch, batch_idx):

@@ -212,9 +207,11 @@ class Normal(pl.LightningModule):
             [stage, loss_name] = key.split("/")
             metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)

+        self.log_dict(
+            metrics_log,
+            prog_bar=False,
+            logger=True,
+            on_step=False,
+            on_epoch=True,
+            rank_zero_only=True
+        )
```
apps/avatarizer.py
CHANGED
```diff
@@ -44,7 +44,8 @@ smpl_model = smplx.create(
     use_pca=False,
     num_betas=200,
     num_expression_coeffs=50,
+    ext='pkl'
+)

 smpl_out_lst = []

@@ -62,7 +63,9 @@ for pose_type in ["t-pose", "da-pose", "pose"]:
             return_full_pose=True,
             return_joint_transformation=True,
             return_vertex_transformation=True,
+            pose_type=pose_type
+        )
+    )

 smpl_verts = smpl_out_lst[2].vertices.detach()[0]
 smpl_tree = cKDTree(smpl_verts.cpu().numpy())

@@ -74,7 +77,8 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
     econ_verts = torch.tensor(econ_obj.vertices).float()
     rot_mat_t = smpl_out_lst[2].vertex_transformation.detach()[0][idx[:, 0]]
     homo_coord = torch.ones_like(econ_verts)[..., :1]
+    econ_cano_verts = torch.inverse(rot_mat_t) @ torch.cat([econ_verts, homo_coord],
+                                                           dim=1).unsqueeze(-1)
     econ_cano_verts = econ_cano_verts[:, :3, 0].cpu()
     econ_cano = trimesh.Trimesh(econ_cano_verts, econ_obj.faces)

@@ -84,7 +88,9 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
     econ_da = trimesh.Trimesh(econ_da_verts[:, :3, 0].cpu(), econ_obj.faces)

     # da-pose for SMPL-X
+    smpl_da = trimesh.Trimesh(
+        smpl_out_lst[1].vertices.detach()[0], smpl_model.faces, maintain_orders=True, process=False
+    )
     smpl_da.export(f"{prefix}_smpl_da.obj")

     # remove hands from ECON for next registeration

@@ -97,7 +103,8 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
     # remove SMPL-X hand and face
     register_mask = ~np.isin(
         np.arange(smpl_da.vertices.shape[0]),
+        np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid])
+    )
     register_mask *= ~smplx_container.eyeball_vertex_mask.bool().numpy()
     smpl_da_body = smpl_da.copy()
     smpl_da_body.update_faces(register_mask[smpl_da.faces].all(axis=1))

@@ -115,8 +122,13 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da.
     # remove over-streched+hand faces from ECON
     econ_da_body = econ_da.copy()
     edge_before = np.sqrt(
+        ((econ_obj.vertices[econ_cano.edges[:, 0]] -
+          econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1)
+    )
+    edge_after = np.sqrt(
+        ((econ_da.vertices[econ_cano.edges[:, 0]] -
+          econ_da.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1)
+    )
     edge_diff = edge_after / edge_before.clip(1e-2)
     streched_mask = np.unique(econ_cano.edges[edge_diff > 6])
     mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid)

@@ -148,8 +160,9 @@ econ_J_regressor = (smpl_model.J_regressor[:, idx] * knn_weights[None]).sum(axis
 econ_lbs_weights = (smpl_model.lbs_weights.T[:, idx] * knn_weights[None]).sum(axis=-1).T

 num_posedirs = smpl_model.posedirs.shape[0]
+econ_posedirs = (
+    smpl_model.posedirs.view(num_posedirs, -1, 3)[:, idx, :] * knn_weights[None, ..., None]
+).sum(axis=-2).view(num_posedirs, -1).float()

 econ_J_regressor /= econ_J_regressor.sum(axis=1, keepdims=True)
 econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)

@@ -157,8 +170,9 @@ econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)
 # re-compute da-pose rot_mat for ECON
 rot_mat_da = smpl_out_lst[1].vertex_transformation.detach()[0][idx[:, 0]]
 econ_da_verts = torch.tensor(econ_da.vertices).float()
+econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat(
+    [econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]], dim=1
+).unsqueeze(-1)
 econ_cano_verts = econ_cano_verts[:, :3, 0].double()

 # ----------------------------------------------------

@@ -174,7 +188,8 @@ posed_econ_verts, _ = general_lbs(
     posedirs=econ_posedirs,
     J_regressor=econ_J_regressor,
     parents=smpl_model.parents,
+    lbs_weights=econ_lbs_weights
+)

 econ_pose = trimesh.Trimesh(posed_econ_verts[0].detach(), econ_da.faces)
+econ_pose.export(f"{prefix}_econ_pose.obj")
```
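The avatarizer hunks above un-pose ECON vertices by inverting each vertex's 4x4 skinning transform in homogeneous coordinates. A toy version of that pattern, with shapes spelled out (all tensor names here are illustrative, not from the repo):

```python
import torch

# Each vertex carries its own 4x4 transform T_i; canonical = T_i^{-1} @ [x, y, z, 1]^T.
N = 5
verts = torch.rand(N, 3)                      # posed vertices
T = torch.eye(4).expand(N, 4, 4).clone()      # per-vertex transforms (identity rotation)
T[:, :3, 3] = 0.1                             # give each one a small translation

homo = torch.cat([verts, torch.ones_like(verts[..., :1])], dim=1)  # (N, 4)
cano = torch.inverse(T) @ homo.unsqueeze(-1)                       # (N, 4, 1)
cano = cano[:, :3, 0]                                              # back to (N, 3)
print(torch.allclose(cano, verts - 0.1))                           # True
```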
apps/infer.py
CHANGED
```diff
@@ -34,7 +34,8 @@ from apps.IFGeo import IFGeo
 from pytorch3d.ops import SubdivideMeshes
 from lib.common.config import cfg
 from lib.common.render import query_color
+from lib.common.train_util import init_loss, Format
+from lib.common.imutils import blend_rgb_norm
 from lib.common.BNI import BNI
 from lib.common.BNI_utils import save_normal_tensor
 from lib.dataset.TestDataset import TestDataset

@@ -68,20 +69,25 @@ if __name__ == "__main__":
     device = torch.device(f"cuda:{args.gpu_device}")

     # setting for testing on in-the-wild images
+    cfg_show_list = [
+        "test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True,
+        "batch_size", 1
+    ]

     cfg.merge_from_list(cfg_show_list)
     cfg.freeze()

+    # load normal model
+    normal_net = Normal.load_from_checkpoint(
+        cfg=cfg, checkpoint_path=cfg.normal_path, map_location=device, strict=False
+    )
+    normal_net = normal_net.to(device)
+    normal_net.netG.eval()
+    print(
+        colored(
+            f"Resume Normal Estimator from {Format.start} {cfg.normal_path} {Format.end}", "green"
+        )
+    )

     # SMPLX object
     SMPLX_object = SMPLX()

@@ -89,16 +95,24 @@ if __name__ == "__main__":
     dataset_param = {
         "image_dir": args.in_dir,
         "seg_dir": args.seg_dir,
+        "use_seg": True,    # w/ or w/o segmentation
+        "hps_type": cfg.bni.hps_type,    # pymafx/pixie
         "vol_res": cfg.vol_res,
         "single": args.multi,
     }

     if cfg.bni.use_ifnet:
+        # load IFGeo model
+        ifnet = IFGeo.load_from_checkpoint(
+            cfg=cfg, checkpoint_path=cfg.ifnet_path, map_location=device, strict=False
+        )
+        ifnet = ifnet.to(device)
+        ifnet.netG.eval()
+
+        print(colored(f"Resume IF-Net+ from {Format.start} {cfg.ifnet_path} {Format.end}", "green"))
+        print(colored(f"Complete with {Format.start} IF-Nets+ (Implicit) {Format.end}", "green"))
     else:
+        print(colored(f"Complete with {Format.start} SMPL-X (Explicit) {Format.end}", "green"))

     dataset = TestDataset(dataset_param, device)

@@ -125,13 +139,17 @@ if __name__ == "__main__":
         # 2. SMPL params (xxx_smpl.npy)
         # 3. d-BiNI surfaces (xxx_BNI.obj)
         # 4. seperate face/hand mesh (xxx_hand/face.obj)
+        # 5. full shape impainted by IF-Nets+ after remeshing (xxx_IF.obj)
         # 6. sideded or occluded parts (xxx_side.obj)
         # 7. final reconstructed clothed human (xxx_full.obj)

         os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)

+        in_tensor = {
+            "smpl_faces": data["smpl_faces"],
+            "image": data["img_icon"].to(device),
+            "mask": data["img_mask"].to(device)
+        }

         # The optimizer and variables
         optimed_pose = data["body_pose"].requires_grad_(True)

@@ -139,7 +157,9 @@ if __name__ == "__main__":
         optimed_betas = data["betas"].requires_grad_(True)
         optimed_orient = data["global_orient"].requires_grad_(True)

+        optimizer_smpl = torch.optim.Adam(
+            [optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True
+        )
         scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
             optimizer_smpl,
             mode="min",

@@ -156,10 +176,12 @@ if __name__ == "__main__":

         smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"

+        # remove this line if you change the loop_smpl and obtain different SMPL-X fits
         if osp.exists(smpl_path):

             smpl_verts_lst = []
             smpl_faces_lst = []
+
             for idx in range(N_body):

                 smpl_obj = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj"

@@ -173,10 +195,12 @@ if __name__ == "__main__":
             batch_smpl_faces = torch.stack(smpl_faces_lst)

             # render optimized mesh as normal [-1,1]
+            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
+                batch_smpl_verts, batch_smpl_faces
+            )

             with torch.no_grad():
+                in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)

             in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
             in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]

@@ -194,8 +218,10 @@ if __name__ == "__main__":
                 N_body, N_pose = optimed_pose.shape[:2]

                 # 6d_rot to rot_mat
+                optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1,
+                                                                         6)).view(N_body, 1, 3, 3)
+                optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1,
+                                                                     6)).view(N_body, N_pose, 3, 3)

                 smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model(
                     shape_params=optimed_betas,

@@ -208,11 +234,16 @@ if __name__ == "__main__":
                 )

                 smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
+                smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor(
+                    [1.0, 1.0, -1.0]
+                ).to(device)

                 # landmark errors
+                smpl_joints_3d = (
+                    smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0
+                ) * 0.5
+                in_tensor["smpl_joint"] = smpl_joints[:,
+                                                      dataset.smpl_data.smpl_joint_ids_24_pixie, :]

                 ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device)
                 ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device)

@@ -227,7 +258,7 @@ if __name__ == "__main__":
                 T_mask_F, T_mask_B = dataset.render.get_image(type="mask")

                 with torch.no_grad():
+                    in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)

                 diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
                 diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])

@@ -249,25 +280,37 @@ if __name__ == "__main__":

                 # BUG: PyTorch3D silhouette renderer generates dilated mask
                 bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
+                smpl_arr_fake = torch.cat(
+                    [
+                        in_tensor["T_normal_F"][:, 0].ne(bg_value).float(),
+                        in_tensor["T_normal_B"][:, 0].ne(bg_value).float()
+                    ],
+                    dim=-1
+                )

+                body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)
+                               ).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
                 body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
                 body_overlap_flag = body_overlap < cfg.body_overlap_thres

+                losses["normal"]["value"] = (
+                    diff_F_smpl * body_overlap_mask[..., :512] +
+                    diff_B_smpl * body_overlap_mask[..., 512:]
+                ).mean() / 2.0

                 losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
                 occluded_idx = torch.where(body_overlap_flag)[0]
                 ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95
+                losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) *
+                                            ghum_conf).mean(dim=1)

                 # Weighted sum of the losses
                 smpl_loss = 0.0
+                pbar_desc = "Body Fitting -- "
                 for k in ["normal", "silhouette", "joint"]:
+                    per_loop_loss = (
+                        losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device)
+                    ).mean()
                     pbar_desc += f"{k}: {per_loop_loss:.3f} | "
                     smpl_loss += per_loop_loss
                 pbar_desc += f"Total: {smpl_loss:.3f}"

@@ -279,19 +322,25 @@ if __name__ == "__main__":
                 # save intermediate results / vis_freq and final_step
                 if (i % args.vis_freq == 0) or (i == args.loop_smpl - 1):

+                    per_loop_lst.extend(
+                        [
+                            in_tensor["image"],
+                            in_tensor["T_normal_F"],
+                            in_tensor["normal_F"],
+                            diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1),
+                        ]
+                    )
+                    per_loop_lst.extend(
+                        [
+                            in_tensor["image"],
+                            in_tensor["T_normal_B"],
+                            in_tensor["normal_B"],
+                            diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1),
+                        ]
+                    )
+                    per_data_lst.append(
+                        get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl")
+                    )

                 smpl_loss.backward()
                 optimizer_smpl.step()

@@ -304,14 +353,21 @@ if __name__ == "__main__":
             img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
             torchvision.utils.save_image(
                 torch.cat(
+                    [
+                        data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
+                        (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5
+                    ],
+                    dim=3
+                ), img_crop_path
+            )

             rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
             rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)

             img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
+            torchvision.utils.save_image(
+                torch.cat([data["img_raw"], rgb_norm_F, rgb_norm_B], dim=-1) / 255., img_overlap_path
+            )

             smpl_obj_lst = []

@@ -329,15 +385,28 @@ if __name__ == "__main__":
                 if not osp.exists(smpl_obj_path):
                     smpl_obj.export(smpl_obj_path)
                     smpl_info = {
+                        "betas":
+                            optimed_betas[idx].detach().cpu().unsqueeze(0),
+                        "body_pose":
+                            rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()
+                                                          ).cpu().unsqueeze(0),
+                        "global_orient":
+                            rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()
+                                                          ).cpu().unsqueeze(0),
+                        "transl":
+                            optimed_trans[idx].detach().cpu(),
+                        "expression":
+                            data["exp"][idx].cpu().unsqueeze(0),
+                        "jaw_pose":
+                            rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0),
+                        "left_hand_pose":
+                            rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]
+                                                          ).cpu().unsqueeze(0),
+                        "right_hand_pose":
+                            rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]
+                                                          ).cpu().unsqueeze(0),
+                        "scale":
+                            data["scale"][idx].cpu(),
                     }
                     np.save(
                         smpl_obj_path.replace(".obj", ".npy"),

@@ -359,10 +428,13 @@ if __name__ == "__main__":

         per_data_lst = []

+        batch_smpl_verts = in_tensor["smpl_verts"].detach(
+        ) * torch.tensor([1.0, -1.0, 1.0], device=device)
         batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]]

+        in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(
+            batch_smpl_verts, batch_smpl_faces
+        )

         per_loop_lst = []

@@ -389,7 +461,13 @@ if __name__ == "__main__":
             )

             # BNI process
+            BNI_object = BNI(
+                dir_path=osp.join(args.out_dir, cfg.name, "BNI"),
+                name=data["name"],
+                BNI_dict=BNI_dict,
+                cfg=cfg.bni,
+                device=device
+            )

             BNI_object.extract_surface(False)

@@ -406,29 +484,40 @@ if __name__ == "__main__":
                 side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)

                 # mesh completion via IF-net
+                in_tensor.update(
+                    dataset.depth_to_voxel(
+                        {
+                            "depth_F": BNI_object.F_depth.unsqueeze(0),
+                            "depth_B": BNI_object.B_depth.unsqueeze(0)
+                        }
+                    )
+                )

                 occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
                     0,
                 ] * 3, scale=2.0).data.transpose(2, 1, 0)
                 occupancies = np.flip(occupancies, axis=1)

+                in_tensor["body_voxels"] = torch.tensor(occupancies.copy()
+                                                        ).float().unsqueeze(0).to(device)

                 with torch.no_grad():
+                    sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor)
+                    verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf)

+                if ifnet.clean_mesh_flag:
                     verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)

                 side_mesh = trimesh.Trimesh(verts_IF, faces_IF)
+                side_mesh = remesh_laplacian(side_mesh, side_mesh_path)

             else:
                 side_mesh = apply_vertex_mask(
                     side_mesh,
+                    (
+                        SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
+                        SMPLX_object.eyeball_vertex_mask
+                    ).eq(0).float(),
                 )

             #register side_mesh to BNI surfaces

@@ -448,7 +537,9 @@ if __name__ == "__main__":
             # 3. remove eyeball faces

             # export intermediate meshes
+            BNI_object.F_B_trimesh.export(
+                f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj"
+            )
             full_lst = []

             if "face" in cfg.bni.use_smpl:

@@ -458,37 +549,63 @@ if __name__ == "__main__":
                 face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])

                 # remove face neighbor triangles
+                BNI_object.F_B_trimesh = part_removal(
+                    BNI_object.F_B_trimesh,
+                    face_mesh,
+                    cfg.bni.face_thres,
+                    device,
+                    smplx_mesh,
+                    region="face"
+                )
+                side_mesh = part_removal(
+                    side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face"
+                )
                 face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
                 full_lst += [face_mesh]

             if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]):

+                hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0], )
                 if data['hands_visibility'][idx][0]:
+                    hand_mask.index_fill_(
+                        0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0
+                    )
                 if data['hands_visibility'][idx][1]:
+                    hand_mask.index_fill_(
+                        0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0
+                    )

                 # only hands
                 hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)

                 # remove hand neighbor triangles
+                BNI_object.F_B_trimesh = part_removal(
+                    BNI_object.F_B_trimesh,
+                    hand_mesh,
+                    cfg.bni.hand_thres,
+                    device,
+                    smplx_mesh,
+                    region="hand"
+                )
+                side_mesh = part_removal(
+                    side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
+                )
                 hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
                 full_lst += [hand_mesh]

             full_lst += [BNI_object.F_B_trimesh]

             # initial side_mesh could be SMPLX or IF-net
+            side_mesh = part_removal(
+                side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False
+            )

             full_lst += [side_mesh]

             # # export intermediate meshes
+            BNI_object.F_B_trimesh.export(
+                f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj"
+            )
             side_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj")

             if cfg.bni.use_poisson:

@@ -505,15 +622,22 @@ if __name__ == "__main__":
             rotate_recon_lst = dataset.render.get_image(cam_type="four")
             per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)

+            if cfg.bni.texture_src == 'image':
+
+                # coloring the final mesh (front: RGB pixels, back: normal colors)
+                final_colors = query_color(
+                    torch.tensor(final_mesh.vertices).float(),
+                    torch.tensor(final_mesh.faces).long(),
+                    in_tensor["image"][idx:idx + 1],
+                    device=device,
+                )
+                final_mesh.visual.vertex_colors = final_colors
+                final_mesh.export(final_path)
+
+            elif cfg.bni.texture_src == 'SD':
+
+                # !TODO: add texture from Stable Diffusion
+                pass

             # for video rendering
             in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
```
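The fitting loop above converts the optimized 6D rotation parameters to rotation matrices with `rot6d_to_rotmat` (defined in lib/net/geometry.py, which is not shown in this 50-file view). As a reference, a minimal sketch of the standard Gram-Schmidt construction behind that representation (Zhou et al. 2019); the repository's own implementation may differ in details:

```python
import torch
import torch.nn.functional as F

def rot6d_to_rotmat_sketch(x):
    """Map a (B, 6) continuous 6D rotation representation to (B, 3, 3) matrices
    via Gram-Schmidt. Sketch only; ECON's rot6d_to_rotmat lives in
    lib/net/geometry.py."""
    a1, a2 = x[:, :3], x[:, 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack([b1, b2, b3], dim=-1)

R = rot6d_to_rotmat_sketch(torch.randn(2, 6))
print(torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(2, 3, 3), atol=1e-5))
```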
apps/multi_render.py
CHANGED
```diff
@@ -20,6 +20,4 @@ faces_lst = in_tensor["body_faces"] + in_tensor["BNI_faces"]

 # self-rotated video
 render.load_meshes(verts_lst, faces_lst)
-render.get_rendered_video_multi(
-    in_tensor,
-    f"{root}/{args.name}_cloth.mp4")
+render.get_rendered_video_multi(in_tensor, f"{root}/{args.name}_cloth.mp4")
```
configs/econ.yaml
CHANGED
```diff
@@ -35,3 +35,4 @@ bni:
   face_thres: 6e-2
   thickness: 0.02
   hps_type: "pixie"
+  texture_src: "SD"
```
docs/installation.md
CHANGED
```diff
@@ -9,12 +9,11 @@ cd ECON

 ## Environment

-- Ubuntu 20 / 18
-- GCC=7 (required by [pypoisson](https://github.com/mmolero/pypoisson/issues/13))
+- Ubuntu 20 / 18, (Windows as well, see [issue#7](https://github.com/YuliangXiu/ECON/issues/7))
 - **CUDA=11.4, GPU Memory > 12GB**
 - Python = 3.8
 - PyTorch >= 1.13.0 (official [Get Started](https://pytorch.org/get-started/locally/))
+- Cupy >= 11.3.0 (offcial [Installation](https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi))
 - PyTorch3D (official [INSTALL.md](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md), recommend [install-from-local-clone](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-from-a-local-clone))
```
lib/common/BNI.py
CHANGED
```diff
@@ -1,12 +1,12 @@
+from lib.common.BNI_utils import (
+    verts_inverse_transform, depth_inverse_transform, double_side_bilateral_normal_integration
+)

 import torch
 import trimesh


 class BNI:
-
     def __init__(self, dir_path, name, BNI_dict, cfg, device):

         self.scale = 256.0

@@ -64,22 +64,20 @@ class BNI:

         F_B_verts = torch.cat((F_verts, B_verts), dim=0)
         F_B_faces = torch.cat(
             (bni_result["F_faces"], bni_result["B_faces"] + bni_result["F_faces"].max() + 1), dim=0
+        )

+        self.F_B_trimesh = trimesh.Trimesh(
+            F_B_verts.float(), F_B_faces.long(), process=False, maintain_order=True
+        )

+        self.F_trimesh = trimesh.Trimesh(
+            F_verts.float(), bni_result["F_faces"].long(), process=False, maintain_order=True
+        )

+        self.B_trimesh = trimesh.Trimesh(
+            B_verts.float(), bni_result["B_faces"].long(), process=False, maintain_order=True
+        )


 if __name__ == "__main__":

@@ -93,16 +91,18 @@ if __name__ == "__main__":
     bni_dict = np.load(npy_file, allow_pickle=True).item()

     default_cfg = {'k': 2, 'lambda1': 1e-4, 'boundary_consist': 1e-6}

     # for k in [1, 2, 4, 10, 100]:
     #     default_cfg['k'] = k
     # for k in [1e-8, 1e-4, 1e-2, 1e-1, 1]:
+    #     default_cfg['lambda1'] = k
     # for k in [1e-4, 1e-2, 0]:
+    #     default_cfg['boundary_consist'] = k

+    bni_object = BNI(
+        osp.dirname(npy_file), osp.basename(npy_file), bni_dict, default_cfg,
+        torch.device('cuda:0')
+    )

     bni_object.extract_surface()
     bni_object.F_trimesh.export(osp.join(osp.dirname(npy_file), "F.obj"))
```
lib/common/BNI_utils.py
CHANGED
```diff
@@ -53,8 +53,9 @@ def find_contour(mask, method='all'):

         contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
     else:
-        contours, _ = cv2.findContours(

     contour_cloth = np.array(find_max_list(contours))[:, 0, :]

@@ -67,16 +68,19 @@ def mean_value_cordinates(inner_pts, contour_pts):
     body_edges_c = np.roll(body_edges_a, shift=-1, axis=1)
     body_edges_b = np.sqrt(((contour_pts - np.roll(contour_pts, shift=-1, axis=0))**2).sum(axis=1))

-    body_edges = np.concatenate(

     body_cos = (body_edges[:, :, 0]**2 + body_edges[:, :, 1]**2 -
                 body_edges[:, :, 2]**2) / (2 * body_edges[:, :, 0] * body_edges[:, :, 1])
     body_tan_half = np.sqrt(
-        (1. - np.clip(body_cos, a_max=1., a_min=-1.)) / np.clip(1. + body_cos, 1e-6, 2.)

     w = (body_tan_half + np.roll(body_tan_half, shift=1, axis=1)) / body_edges_a
     w /= w.sum(axis=1, keepdims=True)

@@ -97,16 +101,18 @@ def dispCorres(img_size, contour1, contour2, phi, dir_path):
     contour2 = contour2[None, :, None, :].astype(np.int32)

     disp = np.zeros((img_size, img_size, 3), dtype=np.uint8)
-    cv2.drawContours(disp, contour1, -1, (0, 255, 0), 1)
-    cv2.drawContours(disp, contour2, -1, (255, 0, 0), 1)

-    for i in range(contour1.shape[1]):
         # cv2.circle(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), 1,
         #            (255, 0, 0), -1)
         corresPoint = contour2[0, phi[i], 0]
         # cv2.circle(disp, (corresPoint[0], corresPoint[1]), 1, (0, 255, 0), -1)
-        cv2.line(

     cv2.imwrite(osp.join(dir_path, "corres.png"), disp)

@@ -162,7 +168,8 @@ def verts_transform(t, depth_scale):
     t_copy *= depth_scale * 0.5
     t_copy += depth_scale * 0.5
     t_copy = t_copy[:, [1, 0, 2]] * torch.Tensor([2.0, 2.0, -2.0]) + torch.Tensor(
-        [0.0, 0.0, depth_scale]

     return t_copy

@@ -328,19 +335,22 @@ def construct_facets_from(mask):
     facet_move_top_mask = move_top(mask)
     facet_move_left_mask = move_left(mask)
     facet_move_top_left_mask = move_top_left(mask)
-    facet_top_left_mask = (
     facet_top_right_mask = move_right(facet_top_left_mask)
     facet_bottom_left_mask = move_bottom(facet_top_left_mask)
     facet_bottom_right_mask = move_bottom_right(facet_top_left_mask)

-    return cp.hstack(


 def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):

@@ -364,8 +374,8 @@ def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):
     u[..., 0] = xx
     u[..., 1] = yy
     u[..., 2] = 1
-    u = u[mask].T
-    vertices = (cp.linalg.inv(K) @ u).T * depth_map[mask, cp.newaxis]

     return vertices

@@ -374,7 +384,6 @@ def sigmoid(x, k=1):
     return 1 / (1 + cp.exp(-k * x))

-
 def boundary_excluded_mask(mask):
     top_mask = cp.pad(mask, ((1, 0), (0, 0)), "constant", constant_values=0)[:-1, :]
     bottom_mask = cp.pad(mask, ((0, 1), (0, 0)), "constant", constant_values=0)[1:, :]

@@ -410,22 +419,24 @@ def create_boundary_matrix(mask):
     return B, B_full


-def double_side_bilateral_normal_integration(

     # To avoid confusion, we list the coordinate systems in this code as follows
     #

@@ -467,14 +478,12 @@ def double_side_bilateral_normal_integration(normal_front,
     del normal_map_back

     # right, left, top, bottom
-    A3_f, A4_f, A1_f, A2_f = generate_dx_dy(
-        nz_vertical=nz_back,
-        step_size=step_size)

     has_left_mask = cp.logical_and(move_right(normal_mask), normal_mask)
     has_right_mask = cp.logical_and(move_left(normal_mask), normal_mask)

@@ -498,29 +507,25 @@ def double_side_bilateral_normal_integration(normal_front,
     b_back = cp.concatenate((-nx_back, -nx_back, -ny_back, -ny_back))

     # initialization
-    W_front = spdiags(
-        0,
-        4 * num_normals,
-        4 * num_normals,
-        format="csr")

     z_front = cp.zeros(num_normals, float)
     z_back = cp.zeros(num_normals, float)
     z_combined = cp.concatenate((z_front, z_back))

     B, B_full = create_boundary_matrix(normal_mask)
-    B_mat = lambda_boundary_consistency * coo_matrix(B_full.get().T @ B_full.get())

     energy_list = []

     if depth_mask is not None:
-        depth_mask_flat = depth_mask[normal_mask].astype(bool)
-        z_prior_front = depth_map_front[normal_mask]
```
z_prior_front = depth_map_front[normal_mask]
|
524 |
z_prior_front[~depth_mask_flat] = 0
|
525 |
z_prior_back = depth_map_back[normal_mask]
|
526 |
z_prior_back[~depth_mask_flat] = 0
|
@@ -554,40 +559,43 @@ def double_side_bilateral_normal_integration(normal_front,
|
|
554 |
vstack((csr_matrix((num_normals, num_normals)), A_mat_back))]) + B_mat
|
555 |
b_vec_combined = cp.concatenate((b_vec_front, b_vec_back))
|
556 |
|
557 |
-
D = spdiags(
|
558 |
-
|
|
|
|
|
559 |
|
560 |
-
z_combined, _ = cg(
|
561 |
-
|
562 |
-
|
563 |
-
x0=z_combined,
|
564 |
-
maxiter=cg_max_iter,
|
565 |
-
tol=cg_tol)
|
566 |
z_front = z_combined[:num_normals]
|
567 |
z_back = z_combined[num_normals:]
|
568 |
-
wu_f = sigmoid((A2_f.dot(z_front))**2 - (A1_f.dot(z_front))**2, k)
|
569 |
-
wv_f = sigmoid((A4_f.dot(z_front))**2 - (A3_f.dot(z_front))**2, k)
|
570 |
wu_f[top_boundnary_mask] = 0.5
|
571 |
wu_f[bottom_boundary_mask] = 0.5
|
572 |
wv_f[left_boundary_mask] = 0.5
|
573 |
wv_f[right_boudnary_mask] = 0.5
|
574 |
-
W_front = spdiags(
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
|
|
|
|
582 |
wu_b[top_boundnary_mask] = 0.5
|
583 |
wu_b[bottom_boundary_mask] = 0.5
|
584 |
wv_b[left_boundary_mask] = 0.5
|
585 |
wv_b[right_boudnary_mask] = 0.5
|
586 |
-
W_back = spdiags(
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
|
|
|
|
591 |
|
592 |
energy_old = energy
|
593 |
energy = (A_front_data @ z_front - b_front).T @ W_front @ (A_front_data @ z_front - b_front) + \
|
@@ -603,23 +611,26 @@ def double_side_bilateral_normal_integration(normal_front,
|
|
603 |
if relative_energy < tol:
|
604 |
break
|
605 |
# del A1, A2, A3, A4, nx, ny
|
606 |
-
|
607 |
depth_map_front_est = cp.ones_like(normal_mask, float) * cp.nan
|
608 |
depth_map_front_est[normal_mask] = z_front
|
609 |
|
610 |
depth_map_back_est = cp.ones_like(normal_mask, float) * cp.nan
|
611 |
depth_map_back_est[normal_mask] = z_back
|
612 |
-
|
613 |
# manually cut the intersection
|
614 |
-
normal_mask[depth_map_front_est>=depth_map_back_est] = False
|
615 |
depth_map_front_est[~normal_mask] = cp.nan
|
616 |
depth_map_back_est[~normal_mask] = cp.nan
|
617 |
|
618 |
vertices_front = cp.asnumpy(
|
619 |
-
map_depth_map_to_point_clouds(
|
620 |
-
|
|
|
|
|
621 |
vertices_back = cp.asnumpy(
|
622 |
-
map_depth_map_to_point_clouds(depth_map_back_est, normal_mask, K=None, step_size=step_size)
|
|
|
623 |
|
624 |
facets_back = cp.asnumpy(construct_facets_from(normal_mask))
|
625 |
|
@@ -656,7 +667,7 @@ def save_normal_tensor(in_tensor, idx, png_path, thickness=0.0):
|
|
656 |
depth_B_arr = depth2arr(in_tensor["depth_B"][idx])
|
657 |
|
658 |
BNI_dict = {}
|
659 |
-
|
660 |
# clothed human
|
661 |
BNI_dict["normal_F"] = normal_F_arr
|
662 |
BNI_dict["normal_B"] = normal_B_arr
|
|
|
53 |
|
54 |
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
|
55 |
else:
|
56 |
+
contours, _ = cv2.findContours(
|
57 |
+
mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
58 |
+
)
|
59 |
|
60 |
contour_cloth = np.array(find_max_list(contours))[:, 0, :]
|
61 |
|
|
|
68 |
body_edges_c = np.roll(body_edges_a, shift=-1, axis=1)
|
69 |
body_edges_b = np.sqrt(((contour_pts - np.roll(contour_pts, shift=-1, axis=0))**2).sum(axis=1))
|
70 |
|
71 |
+
body_edges = np.concatenate(
|
72 |
+
[
|
73 |
+
body_edges_a[..., None], body_edges_c[..., None],
|
74 |
+
np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts))
|
75 |
+
],
|
76 |
+
axis=-1
|
77 |
+
)
|
78 |
|
79 |
body_cos = (body_edges[:, :, 0]**2 + body_edges[:, :, 1]**2 -
|
80 |
body_edges[:, :, 2]**2) / (2 * body_edges[:, :, 0] * body_edges[:, :, 1])
|
81 |
body_tan_half = np.sqrt(
|
82 |
+
(1. - np.clip(body_cos, a_max=1., a_min=-1.)) / np.clip(1. + body_cos, 1e-6, 2.)
|
83 |
+
)
|
84 |
|
85 |
w = (body_tan_half + np.roll(body_tan_half, shift=1, axis=1)) / body_edges_a
|
86 |
w /= w.sum(axis=1, keepdims=True)
|
|
|
101 |
contour2 = contour2[None, :, None, :].astype(np.int32)
|
102 |
|
103 |
disp = np.zeros((img_size, img_size, 3), dtype=np.uint8)
|
104 |
+
cv2.drawContours(disp, contour1, -1, (0, 255, 0), 1) # green
|
105 |
+
cv2.drawContours(disp, contour2, -1, (255, 0, 0), 1) # blue
|
106 |
|
107 |
+
for i in range(contour1.shape[1]): # do not show all the points when display
|
108 |
# cv2.circle(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), 1,
|
109 |
# (255, 0, 0), -1)
|
110 |
corresPoint = contour2[0, phi[i], 0]
|
111 |
# cv2.circle(disp, (corresPoint[0], corresPoint[1]), 1, (0, 255, 0), -1)
|
112 |
+
cv2.line(
|
113 |
+
disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), (corresPoint[0], corresPoint[1]),
|
114 |
+
(255, 255, 255), 1
|
115 |
+
)
|
116 |
|
117 |
cv2.imwrite(osp.join(dir_path, "corres.png"), disp)
|
118 |
|
|
|
168 |
t_copy *= depth_scale * 0.5
|
169 |
t_copy += depth_scale * 0.5
|
170 |
t_copy = t_copy[:, [1, 0, 2]] * torch.Tensor([2.0, 2.0, -2.0]) + torch.Tensor(
|
171 |
+
[0.0, 0.0, depth_scale]
|
172 |
+
)
|
173 |
|
174 |
return t_copy
|
175 |
|
|
|
335 |
facet_move_top_mask = move_top(mask)
|
336 |
facet_move_left_mask = move_left(mask)
|
337 |
facet_move_top_left_mask = move_top_left(mask)
|
338 |
+
facet_top_left_mask = (
|
339 |
+
facet_move_top_mask * facet_move_left_mask * facet_move_top_left_mask * mask
|
340 |
+
)
|
341 |
facet_top_right_mask = move_right(facet_top_left_mask)
|
342 |
facet_bottom_left_mask = move_bottom(facet_top_left_mask)
|
343 |
facet_bottom_right_mask = move_bottom_right(facet_top_left_mask)
|
344 |
|
345 |
+
return cp.hstack(
|
346 |
+
(
|
347 |
+
4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)),
|
348 |
+
idx[facet_top_left_mask][:, None],
|
349 |
+
idx[facet_bottom_left_mask][:, None],
|
350 |
+
idx[facet_bottom_right_mask][:, None],
|
351 |
+
idx[facet_top_right_mask][:, None],
|
352 |
+
)
|
353 |
+
).astype(int)
|
354 |
|
355 |
|
356 |
def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1):
|
|
|
374 |
u[..., 0] = xx
|
375 |
u[..., 1] = yy
|
376 |
u[..., 2] = 1
|
377 |
+
u = u[mask].T # 3 x m
|
378 |
+
vertices = (cp.linalg.inv(K) @ u).T * depth_map[mask, cp.newaxis] # m x 3
|
379 |
|
380 |
return vertices
|
381 |
|
|
|
384 |
return 1 / (1 + cp.exp(-k * x))
|
385 |
|
386 |
|
|
|
387 |
def boundary_excluded_mask(mask):
|
388 |
top_mask = cp.pad(mask, ((1, 0), (0, 0)), "constant", constant_values=0)[:-1, :]
|
389 |
bottom_mask = cp.pad(mask, ((0, 1), (0, 0)), "constant", constant_values=0)[1:, :]
|
|
|
419 |
return B, B_full
|
420 |
|
421 |
|
422 |
+
def double_side_bilateral_normal_integration(
|
423 |
+
normal_front,
|
424 |
+
normal_back,
|
425 |
+
normal_mask,
|
426 |
+
depth_front=None,
|
427 |
+
depth_back=None,
|
428 |
+
depth_mask=None,
|
429 |
+
k=2,
|
430 |
+
lambda_normal_back=1,
|
431 |
+
lambda_depth_front=1e-4,
|
432 |
+
lambda_depth_back=1e-2,
|
433 |
+
lambda_boundary_consistency=1,
|
434 |
+
step_size=1,
|
435 |
+
max_iter=150,
|
436 |
+
tol=1e-4,
|
437 |
+
cg_max_iter=5000,
|
438 |
+
cg_tol=1e-3
|
439 |
+
):
|
440 |
|
441 |
# To avoid confusion, we list the coordinate systems in this code as follows
|
442 |
#
|
|
|
478 |
del normal_map_back
|
479 |
|
480 |
# right, left, top, bottom
|
481 |
+
A3_f, A4_f, A1_f, A2_f = generate_dx_dy(
|
482 |
+
normal_mask, nz_horizontal=nz_front, nz_vertical=nz_front, step_size=step_size
|
483 |
+
)
|
484 |
+
A3_b, A4_b, A1_b, A2_b = generate_dx_dy(
|
485 |
+
normal_mask, nz_horizontal=nz_back, nz_vertical=nz_back, step_size=step_size
|
486 |
+
)
|
|
|
|
|
487 |
|
488 |
has_left_mask = cp.logical_and(move_right(normal_mask), normal_mask)
|
489 |
has_right_mask = cp.logical_and(move_left(normal_mask), normal_mask)
|
|
|
507 |
b_back = cp.concatenate((-nx_back, -nx_back, -ny_back, -ny_back))
|
508 |
|
509 |
# initialization
|
510 |
+
W_front = spdiags(
|
511 |
+
0.5 * cp.ones(4 * num_normals), 0, 4 * num_normals, 4 * num_normals, format="csr"
|
512 |
+
)
|
513 |
+
W_back = spdiags(
|
514 |
+
0.5 * cp.ones(4 * num_normals), 0, 4 * num_normals, 4 * num_normals, format="csr"
|
515 |
+
)
|
|
|
|
|
|
|
|
|
516 |
|
517 |
z_front = cp.zeros(num_normals, float)
|
518 |
z_back = cp.zeros(num_normals, float)
|
519 |
z_combined = cp.concatenate((z_front, z_back))
|
520 |
|
521 |
B, B_full = create_boundary_matrix(normal_mask)
|
522 |
+
B_mat = lambda_boundary_consistency * coo_matrix(B_full.get().T @ B_full.get()) #bug
|
523 |
|
524 |
energy_list = []
|
525 |
|
526 |
if depth_mask is not None:
|
527 |
+
depth_mask_flat = depth_mask[normal_mask].astype(bool) # shape: (num_normals,)
|
528 |
+
z_prior_front = depth_map_front[normal_mask] # shape: (num_normals,)
|
529 |
z_prior_front[~depth_mask_flat] = 0
|
530 |
z_prior_back = depth_map_back[normal_mask]
|
531 |
z_prior_back[~depth_mask_flat] = 0
|
|
|
559 |
vstack((csr_matrix((num_normals, num_normals)), A_mat_back))]) + B_mat
|
560 |
b_vec_combined = cp.concatenate((b_vec_front, b_vec_back))
|
561 |
|
562 |
+
D = spdiags(
|
563 |
+
1 / cp.clip(A_mat_combined.diagonal(), 1e-5, None), 0, 2 * num_normals, 2 * num_normals,
|
564 |
+
"csr"
|
565 |
+
) # Jacob preconditioner
|
566 |
|
567 |
+
z_combined, _ = cg(
|
568 |
+
A_mat_combined, b_vec_combined, M=D, x0=z_combined, maxiter=cg_max_iter, tol=cg_tol
|
569 |
+
)
|
|
|
|
|
|
|
570 |
z_front = z_combined[:num_normals]
|
571 |
z_back = z_combined[num_normals:]
|
572 |
+
wu_f = sigmoid((A2_f.dot(z_front))**2 - (A1_f.dot(z_front))**2, k) # top
|
573 |
+
wv_f = sigmoid((A4_f.dot(z_front))**2 - (A3_f.dot(z_front))**2, k) # right
|
574 |
wu_f[top_boundnary_mask] = 0.5
|
575 |
wu_f[bottom_boundary_mask] = 0.5
|
576 |
wv_f[left_boundary_mask] = 0.5
|
577 |
wv_f[right_boudnary_mask] = 0.5
|
578 |
+
W_front = spdiags(
|
579 |
+
cp.concatenate((wu_f, 1 - wu_f, wv_f, 1 - wv_f)),
|
580 |
+
0,
|
581 |
+
4 * num_normals,
|
582 |
+
4 * num_normals,
|
583 |
+
format="csr"
|
584 |
+
)
|
585 |
+
|
586 |
+
wu_b = sigmoid((A2_b.dot(z_back))**2 - (A1_b.dot(z_back))**2, k) # top
|
587 |
+
wv_b = sigmoid((A4_b.dot(z_back))**2 - (A3_b.dot(z_back))**2, k) # right
|
588 |
wu_b[top_boundnary_mask] = 0.5
|
589 |
wu_b[bottom_boundary_mask] = 0.5
|
590 |
wv_b[left_boundary_mask] = 0.5
|
591 |
wv_b[right_boudnary_mask] = 0.5
|
592 |
+
W_back = spdiags(
|
593 |
+
cp.concatenate((wu_b, 1 - wu_b, wv_b, 1 - wv_b)),
|
594 |
+
0,
|
595 |
+
4 * num_normals,
|
596 |
+
4 * num_normals,
|
597 |
+
format="csr"
|
598 |
+
)
|
599 |
|
600 |
energy_old = energy
|
601 |
energy = (A_front_data @ z_front - b_front).T @ W_front @ (A_front_data @ z_front - b_front) + \
|
|
|
611 |
if relative_energy < tol:
|
612 |
break
|
613 |
# del A1, A2, A3, A4, nx, ny
|
614 |
+
|
615 |
depth_map_front_est = cp.ones_like(normal_mask, float) * cp.nan
|
616 |
depth_map_front_est[normal_mask] = z_front
|
617 |
|
618 |
depth_map_back_est = cp.ones_like(normal_mask, float) * cp.nan
|
619 |
depth_map_back_est[normal_mask] = z_back
|
620 |
+
|
621 |
# manually cut the intersection
|
622 |
+
normal_mask[depth_map_front_est >= depth_map_back_est] = False
|
623 |
depth_map_front_est[~normal_mask] = cp.nan
|
624 |
depth_map_back_est[~normal_mask] = cp.nan
|
625 |
|
626 |
vertices_front = cp.asnumpy(
|
627 |
+
map_depth_map_to_point_clouds(
|
628 |
+
depth_map_front_est, normal_mask, K=None, step_size=step_size
|
629 |
+
)
|
630 |
+
)
|
631 |
vertices_back = cp.asnumpy(
|
632 |
+
map_depth_map_to_point_clouds(depth_map_back_est, normal_mask, K=None, step_size=step_size)
|
633 |
+
)
|
634 |
|
635 |
facets_back = cp.asnumpy(construct_facets_from(normal_mask))
|
636 |
|
|
|
667 |
depth_B_arr = depth2arr(in_tensor["depth_B"][idx])
|
668 |
|
669 |
BNI_dict = {}
|
670 |
+
|
671 |
# clothed human
|
672 |
BNI_dict["normal_F"] = normal_F_arr
|
673 |
BNI_dict["normal_B"] = normal_B_arr
|
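The inner solve above preconditions conjugate gradient with the inverse diagonal of the stacked front/back system (a Jacobi preconditioner) before every re-weighting step. A small self-contained sketch of the same idea, using SciPy as a stand-in for the cupy/cupyx calls in the diff (the test matrix and sizes are made up):

```python
import numpy as np
from scipy.sparse import identity, random as sparse_random, spdiags
from scipy.sparse.linalg import cg

n = 200
A = sparse_random(n, n, density=0.05, format="csr", random_state=0)
A = A @ A.T + 10.0 * identity(n, format="csr")    # symmetric positive-definite test matrix
b = np.random.default_rng(0).random(n)

# Jacobi preconditioner: approximate A^-1 by the inverse of its (clipped) diagonal
D = spdiags(1.0 / np.clip(A.diagonal(), 1e-5, None), 0, n, n, format="csr")

x, info = cg(A, b, M=D, x0=np.zeros(n), maxiter=5000)
print(info, np.linalg.norm(A @ x - b))    # info == 0 means the solver converged
```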
lib/common/blender_utils.py
CHANGED
@@ -3,6 +3,7 @@ import sys, os

from math import radians
import mathutils
import bmesh

print(sys.exec_prefix)
from tqdm import tqdm
import numpy as np

@@ -47,13 +47,16 @@ compositor_alpha = 0.7

# Helper functions
##################################################


def blender_print(*args, **kwargs):
    print(*args, **kwargs, file=sys.stderr)


def using_app():
    ''' Returns if script is running through Blender application (GUI or background processing)'''
    return (not sys.argv[0].endswith('.py'))


def setup_diffuse_transparent_material(target, color, object_transparent, backface_transparent):
    ''' Sets up diffuse/transparent material with backface culling in cycles'''

@@ -150,12 +155,13 @@ def setup_scene():

    if cycles_gpu:
        print('Activating GPU acceleration')
        bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'

        if bpy.app.version[0] >= 3:
            cuda_devices = bpy.context.preferences.addons[
                'cycles'].preferences.get_devices_for_type(compute_device_type='CUDA')
        else:
            (cuda_devices, opencl_devices
            ) = bpy.context.preferences.addons['cycles'].preferences.get_devices()

        if (len(cuda_devices) < 1):
            print('ERROR: CUDA GPU acceleration not available')

@@ -288,13 +296,13 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct):

    # Import object into scene
    bpy.ops.import_scene.obj(filepath=path)
    object = bpy.context.selected_objects[0]

    object.rotation_euler = (radians(90.0), 0.0, radians(yaw))
    z_bottom = np.min(np.array([vert.co for vert in object.data.vertices])[:, 1])
    # z_top = np.max(np.array([vert.co for vert in object.data.vertices])[:,1])
    # blender_print(radians(90.0), z_bottom, z_top)
    object.location -= mathutils.Vector((0.0, 0.0, z_bottom))

    if quads:
        bpy.context.view_layer.objects.active = object
        bpy.ops.object.mode_set(mode='EDIT')

@@ -309,11 +317,11 @@

    bpy.ops.object.mode_set(mode='EDIT')
    bpy.ops.mesh.mark_freestyle_edge(clear=False)
    bpy.ops.object.mode_set(mode='OBJECT')

    if correct:
        diffuse_color = (18 / 255., 139 / 255., 142 / 255., 1)    #correct
    else:
        diffuse_color = (251 / 255., 60 / 255., 60 / 255., 1)    #wrong

    setup_diffuse_transparent_material(object, diffuse_color, object_transparent, mouth_transparent)

@@ -336,10 +344,10 @@

    bpy.ops.render.render(write_still=True)

    # Remove temporary output redirection
    # sys.stdout.flush()
    # os.close(1)
    # os.dup(old)
    # os.close(old)

    # Delete last selected object from scene
    object.select_set(True)

@@ -351,7 +359,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True):

    global quality_preview

    if not input_file.endswith('.obj'):
        print('ERROR: Invalid input: ' + input_file)
        return

    print('Processing: ' + input_file)

@@ -361,7 +369,7 @@

    if quality_preview:
        output_file = output_file.replace('.png', '-preview.png')

    angle = 360.0 / views
    pbar = tqdm(range(0, views))
    for view in pbar:
        pbar.set_description(f"{os.path.basename(output_file)} | View:{str(view)}")

@@ -369,8 +377,7 @@

        output_file_view = f"{output_file}/{view:03d}.png"
        if not os.path.exists(os.path.join(output_dir, output_file_view)):
            render_file(input_file, input_dir, output_file_view, output_dir, yaw, correct)

    cmd = "ffmpeg -loglevel quiet -r 30 -f lavfi -i color=c=white:s=512x512 -i " + os.path.join(output_dir, output_file, '%3d.png') + \
          " -shortest -filter_complex \"[0:v][1:v]overlay=shortest=1,format=yuv420p[out]\" -map \"[out]\" -y " + output_dir+"/"+output_file+".mp4"
    os.system(cmd)

(The remaining hunks in this file only add or remove blank lines for yapf formatting.)
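The turntable script above finishes by stitching the per-view PNGs into an MP4 with the ffmpeg command shown. For reference, the same invocation driven through subprocess instead of os.system would look roughly like this (the paths are placeholders, not paths from the repo):

```python
import subprocess

frame_pattern = "/tmp/econ_render/avatar/%3d.png"    # hypothetical frames written by render_file()
out_video = "/tmp/econ_render/avatar.mp4"            # hypothetical output path

cmd = (
    "ffmpeg -loglevel quiet -r 30 -f lavfi -i color=c=white:s=512x512 "
    f"-i {frame_pattern} -shortest "
    "-filter_complex \"[0:v][1:v]overlay=shortest=1,format=yuv420p[out]\" "
    f"-map \"[out]\" -y {out_video}"
)
subprocess.run(cmd, shell=True, check=True)    # overlays the renders on a white background
```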
lib/common/cloth_extraction.py
CHANGED
@@ -36,11 +36,13 @@ def load_segmentation(path, shape):

        xy = np.vstack((x, y)).T
        coordinates.append(xy)

        segmentations.append(
            {
                "type": val["category_name"],
                "type_id": val["category_id"],
                "coordinates": coordinates,
            }
        )

    return segmentations

@@ -56,9 +58,8 @@ def smpl_to_recon_labels(recon, smpl, k=1):

    Returns a dictionary containing the bodypart and the corresponding indices
    """
    smpl_vert_segmentation = json.load(
        open(os.path.join(os.path.dirname(__file__), "smpl_vert_segmentation.json"))
    )
    n = smpl.vertices.shape[0]
    y = np.array([None] * n)
    for key, val in smpl_vert_segmentation.items():

@@ -71,8 +72,7 @@

    recon_labels = {}
    for key in smpl_vert_segmentation.keys():
        recon_labels[key] = list(np.argwhere(y_pred == key).flatten().astype(int))

    return recon_labels

@@ -139,8 +139,7 @@ def extract_cloth(recon, segmentation, K, R, t, smpl=None):

    if type == 1 or type == 3 or type == 10:
        body_parts_to_remove += ["leftForeArm", "rightForeArm"]
    # No sleeves at all or lower body clothes
    elif (type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9):
        body_parts_to_remove += [
            "leftForeArm",
            "rightForeArm",

@@ -159,8 +158,8 @@

    ]

    verts_to_remove = list(
        itertools.chain.from_iterable([recon_labels[part] for part in body_parts_to_remove])
    )

    label_mask = np.zeros(num_verts, dtype=bool)
    label_mask[verts_to_remove] = True
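smpl_to_recon_labels() above transfers the per-vertex part labels of the fitted SMPL body onto the reconstructed mesh so that unwanted parts (for example forearms under sleeves) can be masked out. The core idea, sketched here with a plain nearest-neighbour lookup on random stand-in data, is only an illustration and not the repo's exact k-NN routine:

```python
import numpy as np
from scipy.spatial import cKDTree

smpl_verts = np.random.rand(6890, 3)             # SMPL template vertices (hypothetical)
recon_verts = np.random.rand(20000, 3)           # reconstructed mesh vertices (hypothetical)
smpl_labels = np.random.randint(0, 24, 6890)     # one body-part id per SMPL vertex (hypothetical)

tree = cKDTree(smpl_verts)
_, nn_idx = tree.query(recon_verts, k=1)         # closest SMPL vertex for every recon vertex
recon_labels = smpl_labels[nn_idx]               # the recon vertex inherits that part label

print(recon_labels.shape)                        # (20000,)
```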
lib/common/config.py
CHANGED
@@ -100,6 +100,7 @@ _C.bni.thickness = 0.00

_C.bni.hand_thres = 4e-2
_C.bni.face_thres = 6e-2
_C.bni.hps_type = "pixie"
_C.bni.texture_src = "image"

# kernel_size, stride, dilation, padding

@@ -170,10 +171,10 @@ _C.dataset.rp_type = "pifu900"

_C.dataset.th_type = "train"
_C.dataset.input_size = 512
_C.dataset.rotation_num = 3
_C.dataset.num_precomp = 10    # Number of segmentation classifiers
_C.dataset.num_multiseg = 500    # Number of categories per classifier
_C.dataset.num_knn = 10    # for loss/error
_C.dataset.num_knn_dis = 20    # for accuracy
_C.dataset.num_verts_max = 20000
_C.dataset.zray_type = False
_C.dataset.online_smpl = False

@@ -210,8 +211,7 @@ def get_cfg_defaults():

# Alternatively, provide a way to import the defaults as
# a global singleton:
cfg = _C    # users can `from config import cfg`

# cfg = get_cfg_defaults()
# cfg.merge_from_file('./configs/example.yaml')

@@ -244,9 +244,7 @@ def parse_args(args):

def parse_args_extend(args):
    if args.resume:
        if not os.path.exists(args.log_dir):
            raise ValueError("Experiment are set to resume mode, but log directory does not exist.")

        # load log's cfg
        cfg_file = os.path.join(args.log_dir, "cfg.yaml")
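These yacs defaults (including the newly added bni.texture_src) are meant to be overridden from configs/econ.yaml. A hedged usage sketch, following the commented pattern at the bottom of config.py; the printed values assume nothing in the YAML overrides them:

```python
from lib.common.config import get_cfg_defaults

cfg = get_cfg_defaults()
cfg.merge_from_file("./configs/econ.yaml")    # YAML values override the _C defaults above
cfg.freeze()

print(cfg.bni.texture_src)    # "image" unless overridden
print(cfg.bni.hps_type)       # "pixie" or "pymafx"
```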
lib/common/imutils.py
CHANGED
@@ -3,14 +3,13 @@ import mediapipe as mp
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
import torch.nn.functional as F
|
6 |
-
from rembg import remove
|
7 |
-
from rembg.session_factory import new_session
|
8 |
from PIL import Image
|
9 |
-
from torchvision.models import detection
|
10 |
-
|
11 |
from lib.pymafx.core import constants
|
12 |
-
|
|
|
|
|
13 |
from torchvision import transforms
|
|
|
14 |
|
15 |
|
16 |
def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
|
@@ -24,42 +23,40 @@ def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
|
|
24 |
return transforms.Compose(all_ops)
|
25 |
|
26 |
|
27 |
-
def
|
28 |
-
dx = (w2 - w1) / 2.0
|
29 |
-
dy = (h2 - h1) / 2.0
|
30 |
-
|
31 |
-
matrix_trans = np.array([[1.0, 0, dx], [0, 1.0, dy], [0, 0, 1.0]])
|
32 |
-
|
33 |
-
scale = np.min([float(w2) / w1, float(h2) / h1])
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
M =
|
39 |
|
40 |
return M
|
41 |
|
42 |
|
43 |
-
def
|
44 |
-
cx, cy = center
|
45 |
-
tx, ty = translate
|
46 |
-
|
47 |
-
M = [1, 0, 0, 0, 1, 0]
|
48 |
-
M = [x * scale for x in M]
|
49 |
|
50 |
-
#
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
# Apply center translation: T * C * RSS * C^-1
|
55 |
-
M[2] += cx + tx
|
56 |
-
M[5] += cy + ty
|
57 |
return M
|
58 |
|
59 |
|
60 |
def load_img(img_file):
|
61 |
|
62 |
img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
|
|
|
|
|
|
|
|
|
|
|
63 |
if len(img.shape) == 2:
|
64 |
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
65 |
|
@@ -68,11 +65,10 @@ def load_img(img_file):
|
|
68 |
else:
|
69 |
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
|
70 |
|
71 |
-
return img
|
72 |
|
73 |
|
74 |
def get_keypoints(image):
|
75 |
-
|
76 |
def collect_xyv(x, body=True):
|
77 |
lmk = x.landmark
|
78 |
all_lmks = []
|
@@ -84,8 +80,8 @@ def get_keypoints(image):
|
|
84 |
mp_holistic = mp.solutions.holistic
|
85 |
|
86 |
with mp_holistic.Holistic(
|
87 |
-
|
88 |
-
|
89 |
) as holistic:
|
90 |
results = holistic.process(image)
|
91 |
|
@@ -93,9 +89,15 @@ def get_keypoints(image):
|
|
93 |
|
94 |
result = {}
|
95 |
result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps
|
96 |
-
result["lhand"] = collect_xyv(
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
return result
|
101 |
|
@@ -104,13 +106,21 @@ def get_pymafx(image, landmarks):
|
|
104 |
|
105 |
# image [3,512,512]
|
106 |
|
107 |
-
item = {
|
|
|
|
|
|
|
108 |
|
109 |
for part in ['lhand', 'rhand', 'face']:
|
110 |
kp2d = landmarks[part]
|
111 |
kp2d_valid = kp2d[kp2d[:, 3] > 0.]
|
112 |
if len(kp2d_valid) > 0:
|
113 |
-
bbox = [
|
|
|
|
|
|
|
|
|
|
|
114 |
center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.]
|
115 |
scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
|
116 |
|
@@ -141,20 +151,6 @@ def get_pymafx(image, landmarks):
|
|
141 |
return item
|
142 |
|
143 |
|
144 |
-
def expand_bbox(bbox, width, height, ratio=0.1):
|
145 |
-
|
146 |
-
bbox = np.around(bbox).astype(np.int16)
|
147 |
-
bbox_width = bbox[2] - bbox[0]
|
148 |
-
bbox_height = bbox[3] - bbox[1]
|
149 |
-
|
150 |
-
bbox[1] = max(bbox[1] - bbox_height * ratio, 0)
|
151 |
-
bbox[3] = min(bbox[3] + bbox_height * ratio, height)
|
152 |
-
bbox[0] = max(bbox[0] - bbox_width * ratio, 0)
|
153 |
-
bbox[2] = min(bbox[2] + bbox_width * ratio, width)
|
154 |
-
|
155 |
-
return bbox
|
156 |
-
|
157 |
-
|
158 |
def remove_floats(mask):
|
159 |
|
160 |
# 1. find all the contours
|
@@ -173,51 +169,48 @@ def remove_floats(mask):
|
|
173 |
return new_mask
|
174 |
|
175 |
|
176 |
-
def process_image(img_file, hps_type, single, input_res
|
177 |
|
178 |
-
img_raw = load_img(img_file)
|
179 |
-
|
180 |
-
in_height,
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
185 |
|
186 |
# detection for bbox
|
187 |
-
|
188 |
-
detector.eval()
|
189 |
-
predictions = detector([torch.from_numpy(img_square).permute(2, 0, 1) / 255.])[0]
|
190 |
|
191 |
if single:
|
192 |
top_score = predictions["scores"][predictions["labels"] == 1].max()
|
193 |
human_ids = torch.where(predictions["scores"] == top_score)[0]
|
194 |
else:
|
195 |
-
human_ids = torch.logical_and(predictions["labels"] == 1,
|
|
|
196 |
|
197 |
boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy()
|
198 |
masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy()
|
199 |
|
200 |
-
|
201 |
-
height = boxes[:, 3] - boxes[:, 1] #(N,)
|
202 |
-
center = np.array([(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]).T #(N,2)
|
203 |
-
scale = np.array([width, height]).max(axis=0) / 90.
|
204 |
|
205 |
img_icon_lst = []
|
206 |
img_crop_lst = []
|
207 |
img_hps_lst = []
|
208 |
img_mask_lst = []
|
209 |
-
uncrop_param_lst = []
|
210 |
landmark_lst = []
|
211 |
hands_visibility_lst = []
|
212 |
img_pymafx_lst = []
|
213 |
|
214 |
uncrop_param = {
|
215 |
-
"center": center,
|
216 |
-
"scale": scale,
|
217 |
"ori_shape": [in_height, in_width],
|
218 |
"box_shape": [input_res, input_res],
|
219 |
-
"
|
220 |
-
"
|
|
|
221 |
}
|
222 |
|
223 |
for idx in range(len(boxes)):
|
@@ -228,59 +221,74 @@ def process_image(img_file, hps_type, single, input_res=512):
|
|
228 |
else:
|
229 |
mask_detection = masks[0] * 0.
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
img_rembg = remove(img_crop, post_process_mask=True, session=new_session("u2net"))
|
236 |
img_mask = remove_floats(img_rembg[:, :, [3]])
|
237 |
|
238 |
-
# required image tensors / arrays
|
239 |
-
|
240 |
-
# img_icon (tensor): (-1, 1), [3,512,512]
|
241 |
-
# img_hps (tensor): (-2.11, 2.44), [3,224,224]
|
242 |
-
|
243 |
-
# img_np (array): (0, 255), [512,512,3]
|
244 |
-
# img_rembg (array): (0, 255), [512,512,4]
|
245 |
-
# img_mask (array): (0, 1), [512,512,1]
|
246 |
-
# img_crop (array): (0, 255), [512,512,4]
|
247 |
-
|
248 |
mean_icon = std_icon = (0.5, 0.5, 0.5)
|
249 |
img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8)
|
250 |
-
img_icon = transform_to_tensor(512, mean_icon, std_icon)(
|
251 |
-
|
252 |
-
|
|
|
|
|
253 |
|
254 |
landmarks = get_keypoints(img_np)
|
255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
if hps_type == 'pymafx':
|
257 |
img_pymafx_lst.append(
|
258 |
get_pymafx(
|
259 |
-
transform_to_tensor(512, constants.IMG_NORM_MEAN,
|
260 |
-
|
|
|
|
|
261 |
|
262 |
img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0)
|
263 |
img_icon_lst.append(img_icon)
|
264 |
img_hps_lst.append(img_hps)
|
265 |
img_mask_lst.append(torch.tensor(img_mask[..., 0]))
|
266 |
-
uncrop_param_lst.append(uncrop_param)
|
267 |
landmark_lst.append(landmarks['body'])
|
268 |
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
|
|
|
|
275 |
|
276 |
return_dict = {
|
277 |
-
"img_icon": torch.stack(img_icon_lst).float(),
|
278 |
-
"img_crop": torch.stack(img_crop_lst).float(),
|
279 |
-
"img_hps": torch.stack(img_hps_lst).float(),
|
280 |
-
"img_raw": img_raw,
|
281 |
-
"img_mask": torch.stack(img_mask_lst).float(),
|
282 |
"uncrop_param": uncrop_param,
|
283 |
-
"landmark": torch.stack(landmark_lst),
|
284 |
"hands_visibility": hands_visibility_lst,
|
285 |
}
|
286 |
|
@@ -302,250 +310,51 @@ def process_image(img_file, hps_type, single, input_res=512):
|
|
302 |
return return_dict
|
303 |
|
304 |
|
305 |
-
def
|
306 |
-
"""Generate transformation matrix."""
|
307 |
-
h = 100 * scale
|
308 |
-
t = np.zeros((3, 3))
|
309 |
-
t[0, 0] = float(res[1]) / h
|
310 |
-
t[1, 1] = float(res[0]) / h
|
311 |
-
t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
|
312 |
-
t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
|
313 |
-
t[2, 2] = 1
|
314 |
-
|
315 |
-
return t
|
316 |
-
|
317 |
-
|
318 |
-
def transform(pt, center, scale, res, invert=0):
|
319 |
-
"""Transform pixel location to different reference."""
|
320 |
-
t = get_transform(center, scale, res)
|
321 |
-
if invert:
|
322 |
-
t = np.linalg.inv(t)
|
323 |
-
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
|
324 |
-
new_pt = np.dot(t, new_pt)
|
325 |
-
return np.around(new_pt[:2]).astype(np.int16)
|
326 |
-
|
327 |
-
|
328 |
-
def crop(img, center, scale, res):
|
329 |
-
"""Crop image according to the supplied bounding box."""
|
330 |
-
|
331 |
-
img_height, img_width = img.shape[:2]
|
332 |
-
|
333 |
-
# Upper left point
|
334 |
-
ul = np.array(transform([0, 0], center, scale, res, invert=1))
|
335 |
-
|
336 |
-
# Bottom right point
|
337 |
-
br = np.array(transform(res, center, scale, res, invert=1))
|
338 |
-
|
339 |
-
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
340 |
-
if len(img.shape) > 2:
|
341 |
-
new_shape += [img.shape[2]]
|
342 |
-
new_img = np.zeros(new_shape)
|
343 |
-
|
344 |
-
# Range to fill new array
|
345 |
-
new_x = max(0, -ul[0]), min(br[0], img_width) - ul[0]
|
346 |
-
new_y = max(0, -ul[1]), min(br[1], img_height) - ul[1]
|
347 |
-
|
348 |
-
# Range to sample from original image
|
349 |
-
old_x = max(0, ul[0]), min(img_width, br[0])
|
350 |
-
old_y = max(0, ul[1]), min(img_height, br[1])
|
351 |
-
|
352 |
-
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
|
353 |
-
new_img = F.interpolate(
|
354 |
-
torch.tensor(new_img).permute(2, 0, 1).unsqueeze(0), res, mode='bilinear').permute(0, 2, 3,
|
355 |
-
1)[0].numpy().astype(np.uint8)
|
356 |
-
|
357 |
-
return new_img, (old_x, new_x, old_y, new_y, new_shape)
|
358 |
-
|
359 |
-
|
360 |
-
def crop_segmentation(org_coord, res, cropping_parameters):
|
361 |
-
old_x, new_x, old_y, new_y, new_shape = cropping_parameters
|
362 |
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
def corner_align(ul, br):
|
374 |
-
|
375 |
-
if ul[1] - ul[0] != br[1] - br[0]:
|
376 |
-
ul[1] = ul[0] + br[1] - br[0]
|
377 |
-
|
378 |
-
return ul, br
|
379 |
-
|
380 |
-
|
381 |
-
def uncrop(img, center, scale, orig_shape):
|
382 |
-
"""'Undo' the image cropping/resizing.
|
383 |
-
This function is used when evaluating mask/part segmentation.
|
384 |
-
"""
|
385 |
-
|
386 |
-
res = img.shape[:2]
|
387 |
-
|
388 |
-
# Upper left point
|
389 |
-
ul = np.array(transform([0, 0], center, scale, res, invert=1))
|
390 |
-
# Bottom right point
|
391 |
-
br = np.array(transform(res, center, scale, res, invert=1))
|
392 |
-
|
393 |
-
# quick fix
|
394 |
-
ul, br = corner_align(ul, br)
|
395 |
-
|
396 |
-
# size of cropped image
|
397 |
-
crop_shape = [br[1] - ul[1], br[0] - ul[0]]
|
398 |
-
new_img = np.zeros(orig_shape, dtype=np.uint8)
|
399 |
-
|
400 |
-
# Range to fill new array
|
401 |
-
new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
|
402 |
-
new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
|
403 |
|
404 |
-
|
405 |
-
old_x = max(0, ul[0]), min(orig_shape[1], br[0])
|
406 |
-
old_y = max(0, ul[1]), min(orig_shape[0], br[1])
|
407 |
|
408 |
-
|
|
|
409 |
|
410 |
-
|
|
|
411 |
|
412 |
-
|
413 |
|
|
|
414 |
|
415 |
-
def rot_aa(aa, rot):
|
416 |
-
"""Rotate axis angle parameters."""
|
417 |
-
# pose parameters
|
418 |
-
R = np.array([
|
419 |
-
[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
|
420 |
-
[np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
|
421 |
-
[0, 0, 1],
|
422 |
-
])
|
423 |
-
# find the rotation of the body in camera frame
|
424 |
-
per_rdg, _ = cv2.Rodrigues(aa)
|
425 |
-
# apply the global rotation to the global orientation
|
426 |
-
resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
|
427 |
-
aa = (resrot.T)[0]
|
428 |
-
return aa
|
429 |
|
|
|
430 |
|
431 |
-
|
432 |
-
"""Flip rgb images or masks.
|
433 |
-
channels come last, e.g. (256,256,3).
|
434 |
-
"""
|
435 |
-
img = np.fliplr(img)
|
436 |
-
return img
|
437 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
elif len(kp) == 49:
|
447 |
-
if is_smpl:
|
448 |
-
flipped_parts = constants.SMPL_J49_FLIP_PERM
|
449 |
-
else:
|
450 |
-
flipped_parts = constants.J49_FLIP_PERM
|
451 |
-
kp = kp[flipped_parts]
|
452 |
-
kp[:, 0] = -kp[:, 0]
|
453 |
-
return kp
|
454 |
-
|
455 |
-
|
456 |
-
def flip_pose(pose):
|
457 |
-
"""Flip pose.
|
458 |
-
The flipping is based on SMPL parameters.
|
459 |
-
"""
|
460 |
-
flipped_parts = constants.SMPL_POSE_FLIP_PERM
|
461 |
-
pose = pose[flipped_parts]
|
462 |
-
# we also negate the second and the third dimension of the axis-angle
|
463 |
-
pose[1::3] = -pose[1::3]
|
464 |
-
pose[2::3] = -pose[2::3]
|
465 |
-
return pose
|
466 |
-
|
467 |
-
|
468 |
-
def normalize_2d_kp(kp_2d, crop_size=224, inv=False):
|
469 |
-
# Normalize keypoints between -1, 1
|
470 |
-
if not inv:
|
471 |
-
ratio = 1.0 / crop_size
|
472 |
-
kp_2d = 2.0 * kp_2d * ratio - 1.0
|
473 |
-
else:
|
474 |
-
ratio = 1.0 / crop_size
|
475 |
-
kp_2d = (kp_2d + 1.0) / (2 * ratio)
|
476 |
-
|
477 |
-
return kp_2d
|
478 |
-
|
479 |
-
|
480 |
-
def visualize_landmarks(image, joints, color):
|
481 |
-
|
482 |
-
img_w, img_h = image.shape[:2]
|
483 |
-
|
484 |
-
for joint in joints:
|
485 |
-
image = cv2.circle(image, (int(joint[0] * img_w), int(joint[1] * img_h)), 5, color)
|
486 |
-
|
487 |
-
return image
|
488 |
-
|
489 |
-
|
490 |
-
def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
|
491 |
-
"""
|
492 |
-
param joints: [num_joints, 3]
|
493 |
-
param joints_vis: [num_joints, 3]
|
494 |
-
return: target, target_weight(1: visible, 0: invisible)
|
495 |
-
"""
|
496 |
-
num_joints = joints.shape[0]
|
497 |
-
device = joints.device
|
498 |
-
cur_device = torch.device(device.type, device.index)
|
499 |
-
if not hasattr(heatmap_size, "__len__"):
|
500 |
-
# width height
|
501 |
-
heatmap_size = [heatmap_size, heatmap_size]
|
502 |
-
assert len(heatmap_size) == 2
|
503 |
-
target_weight = np.ones((num_joints, 1), dtype=np.float32)
|
504 |
-
if joints_vis is not None:
|
505 |
-
target_weight[:, 0] = joints_vis[:, 0]
|
506 |
-
target = torch.zeros(
|
507 |
-
(num_joints, heatmap_size[1], heatmap_size[0]),
|
508 |
-
dtype=torch.float32,
|
509 |
-
device=cur_device,
|
510 |
)
|
511 |
|
512 |
-
|
513 |
-
|
514 |
-
for joint_id in range(num_joints):
|
515 |
-
mu_x = int(joints[joint_id][0] * heatmap_size[0] + 0.5)
|
516 |
-
mu_y = int(joints[joint_id][1] * heatmap_size[1] + 0.5)
|
517 |
-
# Check that any part of the gaussian is in-bounds
|
518 |
-
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
519 |
-
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
520 |
-
if (ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] or br[0] < 0 or br[1] < 0):
|
521 |
-
# If not, just return the image as is
|
522 |
-
target_weight[joint_id] = 0
|
523 |
-
continue
|
524 |
-
|
525 |
-
# # Generate gaussian
|
526 |
-
size = 2 * tmp_size + 1
|
527 |
-
# x = np.arange(0, size, 1, np.float32)
|
528 |
-
# y = x[:, np.newaxis]
|
529 |
-
# x0 = y0 = size // 2
|
530 |
-
# # The gaussian is not normalized, we want the center value to equal 1
|
531 |
-
# g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
532 |
-
# g = torch.from_numpy(g.astype(np.float32))
|
533 |
-
|
534 |
-
x = torch.arange(0, size, dtype=torch.float32, device=cur_device)
|
535 |
-
y = x.unsqueeze(-1)
|
536 |
-
x0 = y0 = size // 2
|
537 |
-
# The gaussian is not normalized, we want the center value to equal 1
|
538 |
-
g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
|
539 |
-
|
540 |
-
# Usable gaussian range
|
541 |
-
g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
|
542 |
-
g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
|
543 |
-
# Image range
|
544 |
-
img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
|
545 |
-
img_y = max(0, ul[1]), min(br[1], heatmap_size[1])
|
546 |
-
|
547 |
-
v = target_weight[joint_id]
|
548 |
-
if v > 0.5:
|
549 |
-
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
550 |
-
|
551 |
-
return target, target_weight
|
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
import torch.nn.functional as F
|
|
|
|
|
6 |
from PIL import Image
|
|
|
|
|
7 |
from lib.pymafx.core import constants
|
8 |
+
|
9 |
+
from rembg import remove
|
10 |
+
from rembg.session_factory import new_session
|
11 |
from torchvision import transforms
|
12 |
+
from kornia.geometry.transform import get_affine_matrix2d, warp_affine
|
13 |
|
14 |
|
15 |
def transform_to_tensor(res, mean=None, std=None, is_tensor=False):
|
|
|
23 |
return transforms.Compose(all_ops)
|
24 |
|
25 |
|
26 |
+
def get_affine_matrix_wh(w1, h1, w2, h2):
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
+
transl = torch.tensor([(w2 - w1) / 2.0, (h2 - h1) / 2.0]).unsqueeze(0)
|
29 |
+
center = torch.tensor([w1 / 2.0, h1 / 2.0]).unsqueeze(0)
|
30 |
+
scale = torch.min(torch.tensor([w2 / w1, h2 / h1])).repeat(2).unsqueeze(0)
|
31 |
+
M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))
|
32 |
|
33 |
return M
|
34 |
|
35 |
|
36 |
+
def get_affine_matrix_box(boxes, w2, h2):
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
+
# boxes [left, top, right, bottom]
|
39 |
+
width = boxes[:, 2] - boxes[:, 0] #(N,)
|
40 |
+
height = boxes[:, 3] - boxes[:, 1] #(N,)
|
41 |
+
center = torch.tensor(
|
42 |
+
[(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]
|
43 |
+
).T #(N,2)
|
44 |
+
scale = torch.min(torch.tensor([w2 / width, h2 / height]),
|
45 |
+
dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9 #(N,2)
|
46 |
+
transl = torch.tensor([w2 / 2.0 - center[:, 0], h2 / 2.0 - center[:, 1]]).unsqueeze(0) #(N,2)
|
47 |
+
M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))
|
48 |
|
|
|
|
|
|
|
49 |
return M
|
50 |
|
51 |
|
52 |
def load_img(img_file):
|
53 |
|
54 |
img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
|
55 |
+
|
56 |
+
# considering 16-bit image
|
57 |
+
if img.dtype == np.uint16:
|
58 |
+
img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
|
59 |
+
|
60 |
if len(img.shape) == 2:
|
61 |
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
62 |
|
|
|
65 |
else:
|
66 |
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
|
67 |
|
68 |
+
return torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float(), img.shape[:2]
|
69 |
|
70 |
|
71 |
def get_keypoints(image):
|
|
|
72 |
def collect_xyv(x, body=True):
|
73 |
lmk = x.landmark
|
74 |
all_lmks = []
|
|
|
80 |
mp_holistic = mp.solutions.holistic
|
81 |
|
82 |
with mp_holistic.Holistic(
|
83 |
+
static_image_mode=True,
|
84 |
+
model_complexity=2,
|
85 |
) as holistic:
|
86 |
results = holistic.process(image)
|
87 |
|
|
|
89 |
|
90 |
result = {}
|
91 |
result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps
|
92 |
+
result["lhand"] = collect_xyv(
|
93 |
+
results.left_hand_landmarks, False
|
94 |
+
) if results.left_hand_landmarks else fake_kps
|
95 |
+
result["rhand"] = collect_xyv(
|
96 |
+
results.right_hand_landmarks, False
|
97 |
+
) if results.right_hand_landmarks else fake_kps
|
98 |
+
result["face"] = collect_xyv(
|
99 |
+
results.face_landmarks, False
|
100 |
+
) if results.face_landmarks else fake_kps
|
101 |
|
102 |
return result
|
103 |
|
|
|
106 |
|
107 |
# image [3,512,512]
|
108 |
|
109 |
+
item = {
|
110 |
+
'img_body':
|
111 |
+
F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]
|
112 |
+
}
|
113 |
|
114 |
for part in ['lhand', 'rhand', 'face']:
|
115 |
kp2d = landmarks[part]
|
116 |
kp2d_valid = kp2d[kp2d[:, 3] > 0.]
|
117 |
if len(kp2d_valid) > 0:
|
118 |
+
bbox = [
|
119 |
+
min(kp2d_valid[:, 0]),
|
120 |
+
min(kp2d_valid[:, 1]),
|
121 |
+
max(kp2d_valid[:, 0]),
|
122 |
+
max(kp2d_valid[:, 1])
|
123 |
+
]
|
124 |
center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.]
|
125 |
scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
|
126 |
|
|
|
151 |
return item
|
152 |
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
def remove_floats(mask):
|
155 |
|
156 |
# 1. find all the contours
|
|
|
169 |
return new_mask
|
170 |
|
171 |
|
172 |
+
def process_image(img_file, hps_type, single, input_res, detector):
|
173 |
|
174 |
+
img_raw, (in_height, in_width) = load_img(img_file)
|
175 |
+
tgt_res = input_res * 2
|
176 |
+
M_square = get_affine_matrix_wh(in_width, in_height, tgt_res, tgt_res)
|
177 |
+
img_square = warp_affine(
|
178 |
+
img_raw,
|
179 |
+
M_square[:, :2], (tgt_res, ) * 2,
|
180 |
+
mode='bilinear',
|
181 |
+
padding_mode='zeros',
|
182 |
+
align_corners=True
|
183 |
+
)
|
184 |
|
185 |
# detection for bbox
|
186 |
+
predictions = detector(img_square / 255.)[0]
|
|
|
|
|
187 |
|
188 |
if single:
|
189 |
top_score = predictions["scores"][predictions["labels"] == 1].max()
|
190 |
human_ids = torch.where(predictions["scores"] == top_score)[0]
|
191 |
else:
|
192 |
+
human_ids = torch.logical_and(predictions["labels"] == 1,
|
193 |
+
predictions["scores"] > 0.9).nonzero().squeeze(1)
|
194 |
|
195 |
boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy()
|
196 |
masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy()
|
197 |
|
198 |
+
    M_crop = get_affine_matrix_box(boxes, input_res, input_res)

    img_icon_lst = []
    img_crop_lst = []
    img_hps_lst = []
    img_mask_lst = []
    landmark_lst = []
    hands_visibility_lst = []
    img_pymafx_lst = []

    uncrop_param = {
        "ori_shape": [in_height, in_width],
        "box_shape": [input_res, input_res],
        "square_shape": [tgt_res, tgt_res],
        "M_square": M_square,
        "M_crop": M_crop
    }

    for idx in range(len(boxes)):

        ...
        else:
            mask_detection = masks[0] * 0.

        img_square_rgba = torch.cat(
            [img_square.squeeze(0).permute(1, 2, 0),
             torch.tensor(mask_detection < 0.4) * 255],
            dim=2
        )

        img_crop = warp_affine(
            img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2),
            M_crop[idx:idx + 1, :2], (input_res, ) * 2,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True
        ).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)

        # get accurate person segmentation mask
        img_rembg = remove(img_crop, post_process_mask=True, session=new_session("u2net"))
        img_mask = remove_floats(img_rembg[:, :, [3]])

        mean_icon = std_icon = (0.5, 0.5, 0.5)
        img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8)
        img_icon = transform_to_tensor(512, mean_icon, std_icon)(
            Image.fromarray(img_np)
        ) * torch.tensor(img_mask).permute(2, 0, 1)
        img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN,
                                      constants.IMG_NORM_STD)(Image.fromarray(img_np))

        landmarks = get_keypoints(img_np)

        # get hands visibility
        hands_visibility = [True, True]
        if landmarks['lhand'][:, -1].mean() == 0.:
            hands_visibility[0] = False
        if landmarks['rhand'][:, -1].mean() == 0.:
            hands_visibility[1] = False
        hands_visibility_lst.append(hands_visibility)

        if hps_type == 'pymafx':
            img_pymafx_lst.append(
                get_pymafx(
                    transform_to_tensor(512, constants.IMG_NORM_MEAN,
                                        constants.IMG_NORM_STD)(Image.fromarray(img_np)), landmarks
                )
            )

        img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0)
        img_icon_lst.append(img_icon)
        img_hps_lst.append(img_hps)
        img_mask_lst.append(torch.tensor(img_mask[..., 0]))
        landmark_lst.append(landmarks['body'])

    # required image tensors / arrays
    #
    # img_icon  (tensor): (-1, 1),       [3, 512, 512]
    # img_hps   (tensor): (-2.11, 2.44), [3, 224, 224]
    #
    # img_np    (array): (0, 255), [512, 512, 3]
    # img_rembg (array): (0, 255), [512, 512, 4]
    # img_mask  (array): (0, 1),   [512, 512, 1]
    # img_crop  (array): (0, 255), [512, 512, 4]

    return_dict = {
        "img_icon": torch.stack(img_icon_lst).float(),    # [N, 3, res, res]
        "img_crop": torch.stack(img_crop_lst).float(),    # [N, 4, res, res]
        "img_hps": torch.stack(img_hps_lst).float(),    # [N, 3, res, res]
        "img_raw": img_raw,    # [1, 3, H, W]
        "img_mask": torch.stack(img_mask_lst).float(),    # [N, res, res]
        "uncrop_param": uncrop_param,
        "landmark": torch.stack(landmark_lst),    # [N, 33, 4]
        "hands_visibility": hands_visibility_lst,
    }

    ...

    return return_dict


def blend_rgb_norm(norms, data):

    # norms [N, 3, res, res]
    masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1)
    norm_mask = F.interpolate(
        torch.cat([norms, masks], dim=1).detach(),
        size=data["uncrop_param"]["box_shape"],
        mode="bilinear",
        align_corners=False
    )
    final = data["img_raw"].type_as(norm_mask)

    for idx in range(len(norms)):

        norm_pred = (norm_mask[idx:idx + 1, :3, :, :] + 1.0) * 255.0 / 2.0
        mask_pred = norm_mask[idx:idx + 1, 3:4, :, :].repeat(1, 3, 1, 1)

        norm_ori = unwrap(norm_pred, data["uncrop_param"], idx)
        mask_ori = unwrap(mask_pred, data["uncrop_param"], idx)

        final = final * (1.0 - mask_ori) + norm_ori * mask_ori

    return final.detach().cpu()


def unwrap(image, uncrop_param, idx):

    device = image.device

    img_square = warp_affine(
        image,
        torch.inverse(uncrop_param["M_crop"])[idx:idx + 1, :2].to(device),
        uncrop_param["square_shape"],
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True
    )

    img_ori = warp_affine(
        img_square,
        torch.inverse(uncrop_param["M_square"])[:, :2].to(device),
        uncrop_param["ori_shape"],
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True
    )

    return img_ori
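The `unwrap` helper above simply applies the two crop transforms in reverse: invert `M_crop` to go from network input back to the square canvas, then invert `M_square` to go back to the raw image. The following is a rough, self-contained illustration of that round trip, not code from the repo: `to_3x3` and `uncrop` are made-up names, and OpenCV's `warpAffine` stands in for the tensor-based `warp_affine` used above.

import cv2
import numpy as np

def to_3x3(M_2x3):
    # promote a 2x3 affine to 3x3 so it can be inverted and composed
    return np.vstack([M_2x3, [0.0, 0.0, 1.0]])

def uncrop(img_crop, M_square, M_crop, square_shape, ori_shape):
    # undo M_crop first (crop -> square canvas), then M_square (square -> original image)
    inv_crop = np.linalg.inv(to_3x3(M_crop))[:2]
    inv_square = np.linalg.inv(to_3x3(M_square))[:2]
    img_square = cv2.warpAffine(img_crop, inv_crop, tuple(square_shape[::-1]))
    return cv2.warpAffine(img_square, inv_square, tuple(ori_shape[::-1]))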
lib/common/libmesh/inside_mesh.py
CHANGED
(check_mesh_contains and MeshIntersector.query now return hole_points alongside contains; long calls are re-wrapped)

def check_mesh_contains(mesh, points, hash_resolution=512):
    intersector = MeshIntersector(mesh, hash_resolution)
    contains, hole_points = intersector.query(points)
    return contains, hole_points


class MeshIntersector:
    ...
        # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5))

        triangles2d = triangles[:, :, :2]
        self._tri_intersector2d = TriangleIntersector2d(triangles2d, resolution)

    def query(self, points):
        # Rescale points
        ...
        # cull points outside of the axis aligned bounding box
        # this avoids running ray tests unless points are close
        inside_aabb = np.all((0 <= points) & (points <= self.resolution), axis=1)
        if not inside_aabb.any():
            return contains, hole_points
        ...
        points = points[mask]

        # Compute intersection depth and check order
        points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2])

        triangles_intersect = self._triangles[tri_indices]
        points_intersect = points[points_indices]

        depth_intersect, abs_n_2 = self.compute_intersection_depth(
            points_intersect, triangles_intersect
        )

        # Count number of intersections in both directions
        smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2
        ...
        # print('Warning: contains1 != contains2 for some points.')
        contains[mask] = (contains1 & contains2)
        hole_points[mask] = np.logical_xor(contains1, contains2)
        return contains, hole_points

    def compute_intersection_depth(self, points, triangles):
        t1 = triangles[:, 0, :]
        ...


class TriangleIntersector2d:
    ...
        sum_uv = u + v
        contains[mask] = (
            (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) & (0 < sum_uv) &
            (sum_uv < abs_detA)
        )
        return contains
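A hypothetical usage sketch of the updated API: after this change the query also reports "hole points", samples where the two ray-casting directions disagree, which typically flags holes or self-intersections in the mesh. The file name and variable names below are illustrative, and the module needs its compiled Cython extension (see the build note under libvoxelize).

import numpy as np
import trimesh
from lib.common.libmesh.inside_mesh import check_mesh_contains

mesh = trimesh.load("scan.obj", process=False)        # any triangle mesh, ideally watertight
points = np.random.uniform(-1.0, 1.0, size=(10000, 3))    # query samples
contains, hole_points = check_mesh_contains(mesh, points)

occupancy = contains.astype(np.float32)               # 1 inside, 0 outside
print(f"{int(hole_points.sum())} ambiguous samples (ray directions disagree)")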
lib/common/libmesh/setup.py
CHANGED
(formatting only)

from setuptools import setup
from Cython.Build import cythonize
import numpy

setup(name='libmesh', ext_modules=cythonize("*.pyx"), include_dirs=[numpy.get_include()])
lib/common/libvoxelize/setup.py
CHANGED
(formatting only)

from setuptools import setup
from Cython.Build import cythonize

setup(name='libvoxelize', ext_modules=cythonize("*.pyx"))
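A minimal way to build both Cython extensions in place, assuming Cython and numpy are already installed, is `cd lib/common/libmesh && python setup.py build_ext --inplace`, and likewise in `lib/common/libvoxelize`; the installation guide's environment setup should cover this as well.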
lib/common/local_affine.py
CHANGED
(the 'stiffness' loss key is renamed to 'stiff'; long statements are re-wrapped)

# reference: https://github.com/wuhaozhe/pytorch-nicp
class LocalAffine(nn.Module):
    def __init__(self, num_points, batch_size=1, edges=None):
        '''
        specify the number of points, the number of points should be constant across the batch
        ...
        add additional pooling on top of w matrix
        '''
        super(LocalAffine, self).__init__()
        self.A = nn.Parameter(
            torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1)
        )
        self.b = nn.Parameter(
            torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(
                batch_size, num_points, 1, 1
            )
        )
        self.edges = edges
        self.num_points = num_points

    def stiffness(self):
        '''
        ...
        '''
        if self.edges is None:
            raise Exception("edges cannot be none when calculate stiff")
        affine_weight = torch.cat((self.A, self.b), dim=3)
        w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0])
        w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1])
        w_diff = (w1 - w2)**2
        w_rigid = (torch.linalg.det(self.A) - 1.0)**2
        return w_diff, w_rigid

    def forward(self, x):
        '''
        x should have shape of B * N * 3 * 1
        '''
        x = x.unsqueeze(3)
        out_x = torch.matmul(self.A, x)
        out_x = out_x + self.b
        out_x.squeeze_(3)
        stiffness, rigid = self.stiffness()

        return out_x, stiffness, rigid


def register(target_mesh, src_mesh, device):
    ...
    tgt_mesh = trimesh2meshes(target_mesh).to(device)
    src_verts = src_mesh.verts_padded().clone()

    local_affine_model = LocalAffine(
        src_mesh.verts_padded().shape[1],
        src_mesh.verts_padded().shape[0], src_mesh.edges_packed()
    ).to(device)

    optimizer_cloth = torch.optim.Adam(
        [{
            'params': local_affine_model.parameters()
        }], lr=1e-2, amsgrad=True
    )
    scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_cloth,
        mode="min",
        ...
    )

    losses = init_loss()

    loop_cloth = tqdm(range(100))

    for i in loop_cloth:

        optimizer_cloth.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(x=src_verts)
        src_mesh = src_mesh.update_padded(deformed_verts)

        # losses for laplacian, edge, normal consistency
        update_mesh_shape_prior_losses(src_mesh, losses)

        losses["cloth"]["value"] = chamfer_distance(
            x=src_mesh.verts_padded(), y=tgt_mesh.verts_padded()
        )[0]
        losses["stiff"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        # Weighted sum of the losses
        cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
        pbar_desc = "Register SMPL-X -> d-BiNI -- "

        for k in losses.keys():
            if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0:
                cloth_loss = cloth_loss + \
                    losses[k]["value"] * losses[k]["weight"]
                pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.3f} | "

        pbar_desc += f"TOTAL: {cloth_loss:.3f}"
        loop_cloth.set_description(pbar_desc)

        # update params
        ...

    final = trimesh.Trimesh(
        src_mesh.verts_packed().detach().squeeze(0).cpu(),
        src_mesh.faces_packed().detach().squeeze(0).cpu(),
        process=False,
        maintains_order=True
    )

    return final
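Conceptually, LocalAffine gives every vertex its own small affine map y_i = A_i x_i + b_i, with a stiffness term pulling neighbouring transforms together and a rigidity term keeping det(A_i) close to 1. The snippet below is a minimal, self-contained sketch of just that arithmetic (shapes and names are illustrative, not repo code).

import torch

B, N = 1, 5                                  # batch size, number of vertices
A = torch.eye(3).expand(B, N, 3, 3)          # identity init, as in LocalAffine
b = torch.zeros(B, N, 3, 1)                  # zero offsets
x = torch.randn(B, N, 3, 1)                  # vertex positions as column vectors

y = (A @ x + b).squeeze(-1)                  # deformed vertices, [B, N, 3]
rigid = (torch.linalg.det(A) - 1.0) ** 2     # per-vertex rigidity penalty, [B, N]
print(y.shape, rigid.shape)                  # torch.Size([1, 5, 3]) torch.Size([1, 5])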
lib/common/render.py
CHANGED
(blend_rgb_norm is now imported from lib.common.imutils; camera setup and the multi-view video export are re-wrapped)

)
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.structures import Meshes
from lib.dataset.mesh_util import get_visibility
from lib.common.imutils import blend_rgb_norm

import lib.common.render_utils as util
import torch

...

def query_color(verts, faces, image, device):
    ...
    (xy, z) = verts.split([2, 1], dim=1)
    visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
    uv = xy.unsqueeze(0).unsqueeze(2)    # [B, N, 2]
    uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
    colors = (
        (
            torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :,
                                                                           0].permute(1, 0) + 1.0
        ) * 0.5 * 255.0
    )
    colors[visibility == 0.0] = (
        (Meshes(verts.unsqueeze(0), faces.unsqueeze(0)).verts_normals_padded().squeeze(0) + 1.0) *
        0.5 * 255.0
    )[visibility == 0.0]

    return colors.detach().cpu()


class cleanShader(torch.nn.Module):
    def __init__(self, blend_params=None):
        super().__init__()
        self.blend_params = blend_params if blend_params is not None else BlendParams()
    ...


class Render:
    def __init__(self, size=512, device=torch.device("cuda:0")):
        self.device = device
        self.size = size
        ...
        self.cam_pos = {
            "frontback":
                torch.tensor(
                    [
                        (0, self.mesh_y_center, self.dis),
                        (0, self.mesh_y_center, -self.dis),
                    ]
                ),
            "four":
                torch.tensor(
                    [
                        (0, self.mesh_y_center, self.dis),
                        (self.dis, self.mesh_y_center, 0),
                        (0, self.mesh_y_center, -self.dis),
                        (-self.dis, self.mesh_y_center, 0),
                    ]
                ),
            "around":
                torch.tensor(
                    [
                        (
                            100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center,
                            100.0 * math.sin(np.pi / 180 * angle)
                        ) for angle in range(0, 360, self.step)
                    ]
                )
        }

        self.type = "color"
        ...

        R, T = look_at_view_transform(
            eye=self.cam_pos[type][idx],
            at=((0, self.mesh_y_center, 0), ),
            up=((0, 1, 0), ),
        )

        cameras = FoVOrthographicCameras(
            ...
            min_y=-100.0,
            max_x=100.0,
            min_x=-100.0,
            scale_xyz=(self.scale * np.ones(3), ) * len(R),
        )

        return cameras

        ...
            cull_backfaces=True,
        )

        self.silhouetteRas = MeshRasterizer(
            cameras=camera, raster_settings=self.raster_settings_silhouette
        )
        self.renderer = MeshRenderer(
            rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
        )

    elif type == "pointcloud":
        self.raster_settings_pcd = PointsRasterizationSettings(
            image_size=self.size, radius=0.006, points_per_pixel=10
        )

        self.pcdRas = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings_pcd)
        self.renderer = PointsRenderer(
        ...

        V_lst = []
        F_lst = []
        for V, F in zip(verts, faces):
            if not torch.is_tensor(V):
                V_lst.append(torch.tensor(V).float().to(self.device))
                F_lst.append(torch.tensor(F).long().to(self.device))
            else:
                V_lst.append(V.float().to(self.device))
                F_lst.append(F.long().to(self.device))
        self.meshes = Meshes(V_lst, F_lst).to(self.device)
    else:
        # array or tensor
        ...

        # texture only support single mesh
        if len(self.meshes) == 1:
            self.meshes.textures = TexturesVertex(
                verts_features=(self.meshes.verts_normals_padded() + 1.0) * 0.5
            )

    def get_image(self, cam_type="frontback", type="rgb", bg="gray"):
        ...
            current_mesh = self.meshes[mesh_id]
            current_mesh.textures = TexturesVertex(
                verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5
            )

            if type == "depth":
                fragments = self.meshRas(current_mesh.extend(len(self.cam_pos[cam_type])))
            ...
                print(f"unknown {type}")

            if cam_type == 'frontback':
                images[1] = torch.flip(images[1], dims=(-1, ))

            # images [N_render, 3, res, res]
            img_lst.append(images.unsqueeze(1))
        ...
        return list(meshes)

    def get_rendered_video_multi(self, data, save_path):

        height, width = data["img_raw"].shape[2:]

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video = cv2.VideoWriter(
        ...

        pbar = tqdm(range(len(self.meshes)))
        pbar.set_description(colored(f"Normal Rendering {os.path.basename(save_path)}...", "blue"))

        mesh_renders = []    # [(N_cam, 3, res, res) * N_mesh]

        # render all the normals
        for mesh_id in pbar:

            current_mesh = self.meshes[mesh_id]
            current_mesh.textures = TexturesVertex(
                verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5
            )

            norm_lst = []
            ...
                self.init_renderer(batch_cams, "mesh", "gray")

                norm_lst.append(
                    self.renderer(current_mesh.extend(len(batch_cams_idx))
                                  )[..., :3].permute(0, 3, 1, 2)
                )
            mesh_renders.append(torch.cat(norm_lst).detach().cpu())

        # generate video frame by frame
        pbar = tqdm(range(len(self.cam_pos["around"])))
        pbar.set_description(colored(f"Video Exporting {os.path.basename(save_path)}...", "blue"))

        for cam_id in pbar:
            img_raw = data["img_raw"]
            num_obj = len(mesh_renders) // 2
            img_smpl = blend_rgb_norm(
                (torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0, data
            )
            img_cloth = blend_rgb_norm(
                (torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0, data
            )

            top_img = cv2.resize(
                torch.cat([img_raw, img_smpl],
                          dim=-1).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8),
                (width, height // 2)
            )
            final_img = np.concatenate(
                [top_img, img_cloth.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)], axis=0
            )
            video.write(final_img[:, :, ::-1])

        video.release()
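The "around" camera list built in Render.__init__ just places cameras on a circle of radius 100 around the vertical axis, every `step` degrees, all pointing at (0, mesh_y_center, 0). A tiny, self-contained sketch of that computation (the function name is illustrative):

import math
import torch

def around_ring(mesh_y_center: float, step: int, radius: float = 100.0) -> torch.Tensor:
    # one camera position per `step` degrees, on a horizontal circle at mesh_y_center
    return torch.tensor(
        [
            (
                radius * math.cos(math.radians(angle)),
                mesh_y_center,
                radius * math.sin(math.radians(angle)),
            ) for angle in range(0, 360, step)
        ]
    )

print(around_ring(mesh_y_center=0.0, step=90))    # 4 cameras: +x, +z, -x, -z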
lib/common/render_utils.py
CHANGED
(formatting only: long calls collapsed onto single lines)

Tensor = NewType("Tensor", torch.Tensor)


def solid_angles(points: Tensor, triangles: Tensor, thresh: float = 1e-8) -> Tensor:
    """Compute solid angle between the input points and triangles
    Follows the method described in:
    The Solid Angle of a Plane Triangle
    ...
    """
    ...
    norms = torch.norm(centered_tris, dim=-1)

    # Should be BxQxFx3
    cross_prod = torch.cross(centered_tris[:, :, :, 1], centered_tris[:, :, :, 2], dim=-1)
    # Should be BxQxF
    numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
    del cross_prod
    ...
    dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
    del centered_tris

    denominator = (
        norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] + dot02 * norms[:, :, :, 1] +
        dot12 * norms[:, :, :, 0]
    )
    del dot01, dot12, dot02, norms

    # Should be BxQ
    ...
    return 2 * solid_angle


def winding_numbers(points: Tensor, triangles: Tensor, thresh: float = 1e-8) -> Tensor:
    """Uses winding_numbers to compute inside/outside
    Robust inside-outside segmentation using generalized winding numbers
    Alec Jacobson,
    ...
    """
    # The generalized winding number is the sum of solid angles of the point
    # with respect to all triangles.
    return (1 / (4 * math.pi) * solid_angles(points, triangles, thresh=thresh).sum(dim=-1))


def batch_contains(verts, faces, points):
    ...
    contains = torch.zeros(B, N)

    for i in range(B):
        contains[i] = torch.as_tensor(trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))

    return 2.0 * (contains - 0.5)


def face_vertices(vertices, faces):
    ...
    bs, nv = vertices.shape[:2]
    bs, nf = faces.shape[:2]
    device = vertices.device
    faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
    vertices = vertices.reshape((bs * nv, vertices.shape[-1]))

    return vertices[faces.long()]


class Pytorch3dRasterizer(nn.Module):
    """
    x,y,z are in image space, normalized
    can only render squared image now
    """
    def __init__(self, image_size=224, blur_radius=0.0, faces_per_pixel=1):
        """
        use fixed raster_settings for rendering faces
        ...
        """

    def forward(self, vertices, faces, attributes=None):
        fixed_vertices = vertices.clone()
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
        raster_settings = self.raster_settings
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_screen,
            ...
        )
        vismask = (pix_to_face > -1).float()
        D = attributes.shape[-1]
        attributes = attributes.clone()
        attributes = attributes.view(
            attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1]
        )
        N, H, W, K, _ = bary_coords.shape
        mask = pix_to_face == -1
        pix_to_face = pix_to_face.clone()
        ...
        idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0    # Replace masked values in output.
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
        return pixel_vals
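A hypothetical usage sketch for winding_numbers(), assuming the BxQx3 points / BxFx3x3 triangles conventions implied by the docstrings above: points whose generalized winding number exceeds 0.5 are treated as inside the (possibly open) surface. The random data below exists only to show shapes, not to build a meaningful mesh.

import torch
from lib.common.render_utils import winding_numbers, face_vertices

verts = torch.rand(1, 300, 3)                  # BxVx3 vertex positions
faces = torch.randint(0, 300, (1, 600, 3))     # BxFx3 face indices (random, shapes only)
points = torch.rand(1, 1024, 3)                # BxQx3 query points

triangles = face_vertices(verts, faces)        # BxFx3x3 per-face vertex positions
inside = winding_numbers(points, triangles) > 0.5
print(inside.shape)                            # expected: torch.Size([1, 1024])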
lib/common/seg3d_lossless.py
CHANGED
(formatting pass: long statements re-wrapped for readability)


class Seg3dLossless(nn.Module):
    def __init__(
        self,
        query_func,
        ...
    ):
        """
        ...
        """
        super().__init__()
        self.query_func = query_func
        self.register_buffer("b_min", torch.tensor(b_min).float().unsqueeze(1))    # [bz, 1, 3]
        self.register_buffer("b_max", torch.tensor(b_max).float().unsqueeze(1))    # [bz, 1, 3]

        # ti.init(arch=ti.cuda)
        # self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)

        if type(resolutions[0]) is int:
            resolutions = torch.tensor([(res, res, res) for res in resolutions])
        else:
            resolutions = torch.tensor(resolutions)
        self.register_buffer("resolutions", resolutions)
        ...
        ), f"resolution {resolution} need to be odd becuase of align_corner."

        # init first resolution
        init_coords = create_grid3D(0, resolutions[-1] - 1, steps=resolutions[0])    # [N, 3]
        init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1, 1)    # [bz, N, 3]
        self.register_buffer("init_coords", init_coords)

        # some useful tensors
        calculated = torch.zeros(
            (self.resolutions[-1][2], self.resolutions[-1][1], self.resolutions[-1][0]),
            dtype=torch.bool,
        )
        self.register_buffer("calculated", calculated)

        gird8_offsets = (
            torch.stack(
                torch.meshgrid(
                    [
                        torch.tensor([-1, 0, 1]),
                        torch.tensor([-1, 0, 1]),
                        torch.tensor([-1, 0, 1]),
                    ],
                    indexing="ij",
                )
            ).int().view(3, -1).t()
        )    # [27, 3]
        self.register_buffer("gird8_offsets", gird8_offsets)

        # smooth convs
        self.smooth_conv3x3 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3)
        self.smooth_conv5x5 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=5)
        self.smooth_conv7x7 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=7)
        self.smooth_conv9x9 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=9)

    @torch.no_grad()
    def batch_eval(self, coords, **kwargs):
        ...
        # query function
        occupancys = self.query_func(**kwargs, points=coords2D)
        if type(occupancys) is list:
            occupancys = torch.stack(occupancys)    # [bz, C, N]
        assert (
            len(occupancys.size()) == 3
        ), "query_func should return a occupancy with shape of [bz, C, N]"
        ...

        # first step
        if torch.equal(resolution, self.resolutions[0]):
            coords = self.init_coords.clone()    # torch.long
            occupancys = self.batch_eval(coords, **kwargs)
            occupancys = occupancys.view(self.batchsize, self.channels, D, H, W)
            if (occupancys > 0.5).sum() == 0:
                # return F.interpolate(
                #     occupancys, size=(final_D, final_H, final_W),
                ...

        with torch.no_grad():
            if torch.equal(resolution, self.resolutions[1]):
                is_boundary = (self.smooth_conv9x9(is_boundary.float()) > 0)[0, 0]
            elif torch.equal(resolution, self.resolutions[2]):
                is_boundary = (self.smooth_conv7x7(is_boundary.float()) > 0)[0, 0]
            else:
                is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0]

            coords_accum = coords_accum.long()
            is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
                        coords_accum[0, :, 0], ] = False
            point_coords = (
                is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
            )
            point_indices = (
                point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
                point_coords[:, :, 0]
            )

            R, C, D, H, W = occupancys.shape
            ...
            # put mask point predictions to the right places on the upsampled grid.
            R, C, D, H, W = occupancys.shape
            point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
            occupancys = (
                occupancys.reshape(R, C,
                                   D * H * W).scatter_(2, point_indices,
                                                       occupancys_topk).view(R, C, D, H, W)
            )

            with torch.no_grad():
                voxels = coords / stride
                coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)

        return occupancys[0, 0]

    ...

        # first step
        if torch.equal(resolution, self.resolutions[0]):
            coords = self.init_coords.clone()    # torch.long
            occupancys = self.batch_eval(coords, **kwargs)
            occupancys = occupancys.view(self.batchsize, self.channels, D, H, W)

            if self.visualize:
                self.plot(occupancys, coords, final_D, final_H, final_W)

            with torch.no_grad():
                coords_accum = coords / stride
                calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True

        # next steps
        else:
            ...
            with torch.no_grad():
                # TODO
                if self.use_shadow and torch.equal(resolution, self.resolutions[-1]):
                    # larger z means smaller depth here
                    depth_res = resolution[2].item()
                    depth_index = torch.linspace(0, depth_res - 1,
                                                 steps=depth_res).type_as(occupancys.device)
                    depth_index_max = (
                        torch.max(
                            (occupancys > self.balance_value) * (depth_index + 1),
                            dim=-1,
                            keepdim=True,
                        )[0] - 1
                    )
                    shadow = depth_index < depth_index_max
                    is_boundary[shadow] = False
                    is_boundary = is_boundary[0, 0]
                else:
                    is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0]
                    # is_boundary = is_boundary[0, 0]

                is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
                            coords_accum[0, :, 0], ] = False
                point_coords = (
                    is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
                )
                point_indices = (
                    point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
                    point_coords[:, :, 0]
                )

                R, C, D, H, W = occupancys.shape
                # interpolated value
                ...
                # put mask point predictions to the right places on the upsampled grid.
                R, C, D, H, W = occupancys.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                occupancys = (
                    occupancys.reshape(R, C,
                                       D * H * W).scatter_(2, point_indices,
                                                           occupancys_topk).view(R, C, D, H, W)
                )

                with torch.no_grad():
                    # conflicts
                    conflicts = (
                        (occupancys_interp - self.balance_value) *
                        (occupancys_topk - self.balance_value) < 0
                    )[0, 0]

                    if self.visualize:
                        self.plot(occupancys, coords, final_D, final_H, final_W)

                    voxels = coords / stride
                    coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
                    calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True

                while conflicts.sum() > 0:
                    if self.use_shadow and torch.equal(resolution, self.resolutions[-1]):
                        break

                    with torch.no_grad():
                        ...
                        conflicts_boundary = (
                            (
                                conflicts_coords.int() +
                                self.gird8_offsets.unsqueeze(1) * stride.int()
                            ).reshape(-1, 3).long().unique(dim=0)
                        )
                        conflicts_boundary[:, 0] = conflicts_boundary[:, 0].clamp(
                            0,
                            calculated.size(2) - 1
                        )
                        conflicts_boundary[:, 1] = conflicts_boundary[:, 1].clamp(
                            0,
                            calculated.size(1) - 1
                        )
                        conflicts_boundary[:, 2] = conflicts_boundary[:, 2].clamp(
                            0,
                            calculated.size(0) - 1
                        )

                        coords = conflicts_boundary[calculated[conflicts_boundary[:, 2],
                                                               conflicts_boundary[:, 1],
                                                               conflicts_boundary[:, 0], ] == False]

                        if self.debug:
                            self.plot(
                                ...
                            )

                        coords = coords.unsqueeze(0)
                        point_coords = coords / stride
                        point_indices = (
                            point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W +
                            point_coords[:, :, 0]
                        )

                        R, C, D, H, W = occupancys.shape
                        # interpolated value
                        ...

                    with torch.no_grad():
                        # conflicts
                        conflicts = (
                            (occupancys_interp - self.balance_value) *
                            (occupancys_topk - self.balance_value) < 0
                        )[0, 0]

                    # put mask point predictions to the right places on the upsampled grid.
                    point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                    occupancys = (
                        occupancys.reshape(R, C,
                                           D * H * W).scatter_(2, point_indices,
                                                               occupancys_topk).view(R, C, D, H, W)
                    )

                    with torch.no_grad():
                        voxels = coords / stride
                        coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1)
                        calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True

            if self.visualize:
                this_stage_coords = torch.cat(this_stage_coords, dim=1)
                self.plot(occupancys, this_stage_coords, final_D, final_H, final_W)

        return occupancys[0, 0]

    def plot(self, occupancys, coords, final_D, final_H, final_W, title="", **kwargs):
        final = F.interpolate(
            occupancys.float(),
            size=(final_D, final_H, final_W),
            mode="trilinear",
            align_corners=True,
        )    # here true is correct!
        x = coords[0, :, 0].to("cpu")
        y = coords[0, :, 1].to("cpu")
        z = coords[0, :, 2].to("cpu")
        ...

        sdf_all = sdf.permute(2, 1, 0)

        # shadow
        grad_v = (sdf_all > 0.5) * torch.linspace(resolution, 1, steps=resolution).to(sdf.device)
        grad_c = torch.ones_like(sdf_all) * torch.linspace(0, resolution - 1,
                                                           steps=resolution).to(sdf.device)
        max_v, max_c = grad_v.max(dim=2)
        shadow = grad_c > max_c.view(resolution, resolution, 1)
        keep = (sdf_all > 0.5) & (~shadow)

        p1 = keep.nonzero(as_tuple=False).t()    # [3, N]
        p2 = p1.clone()    # z
        p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
        p3 = p1.clone()    # y
        p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
        p4 = p1.clone()    # x
        p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)

        v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
        v2 = sdf_all[p2[0, :], p2[1, :], p2[2, :]]
        v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
        v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]

        X = p1[0, :].long()    # [N,]
        Y = p1[1, :].long()    # [N,]
        Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + p1[2, :].float() * (v2 - 0.5
                                                                            ) / (v2 - v1)    # [N,]
        Z = Z.clamp(0, resolution)

        # normal
        ...

    @torch.no_grad()
    def render_normal(self, resolution, X, Y, Z, norm):
        image = torch.ones((1, 3, resolution, resolution), dtype=torch.float32).to(norm.device)
        color = (norm + 1) / 2.0
        color = color.clamp(0, 1)
        image[0, :, Y, X] = color.t()
        ...

    def export_mesh(self, occupancys):

        final = occupancys[1:, 1:, 1:].contiguous()

        verts, faces = marching_cubes(final.unsqueeze(0), isolevel=0.5)
        verts = verts[0].cpu().float()
        faces = faces[0].cpu().long()[:, [0, 2, 1]]

        return verts, faces
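The core bookkeeping trick above is to linearise boundary voxel coordinates as z*H*W + y*W + x and write newly queried occupancies back into the dense grid with scatter_. A minimal, self-contained illustration of just that flatten-and-scatter step (toy sizes, illustrative names, not repo code):

import torch

R, C, D, H, W = 1, 1, 4, 4, 4
occupancys = torch.zeros(R, C, D, H, W)

point_coords = torch.tensor([[[1, 2, 3], [0, 0, 0]]])    # [R, N, 3] stored as (x, y, z)
point_indices = (
    point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W + point_coords[:, :, 0]
)                                                         # [R, N] flat voxel indices
new_vals = torch.tensor([[0.9, 0.1]])                     # freshly queried occupancies, [R, N]

occupancys = (
    occupancys.reshape(R, C, D * H * W)
    .scatter_(2, point_indices.unsqueeze(1), new_vals.unsqueeze(1))
    .view(R, C, D, H, W)
)
print(occupancys[0, 0, 3, 2, 1])    # tensor(0.9000): voxel (x=1, y=2, z=3) was updated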
lib/common/seg3d_utils.py
CHANGED
(formatting pass; the listing is truncated here)

import matplotlib.pyplot as plt


def plot_mask2D(mask, title="", point_coords=None, figsize=10, point_marker_size=5):
    '''
    Simple plotting tool to show intermediate mask predictions and points
    where PointRend is applied.
    ...
    '''
    ...
    plt.xlabel(W, fontsize=30)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.imshow(mask.detach(), interpolation="nearest", cmap=plt.get_cmap('gray'))
    if point_coords is not None:
        plt.scatter(
            ...
            s=point_marker_size,
            clip_on=True
        )
    plt.xlim(-0.5, W - 0.5)
    plt.ylim(H - 0.5, -0.5)
    plt.show()


def plot_mask3D(
    mask=None,
    ...
    figsize=1500,
    point_marker_size=8,
    interactive=True
):
    '''
    Simple plotting tool to show intermediate mask predictions and points
    where PointRend is applied.
    ...
    '''
    ...
    # marching cube to find surface
    verts, faces, normals, values = measure.marching_cubes_lewiner(
        mask, 0.5, gradient_direction='ascent'
    )

    # create a mesh
    mesh = trimesh.Trimesh(verts, faces)
    ...
    pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
    vis_list.append(pc)

    vp.show(*vis_list, bg="white", axes=1, interactive=interactive, azimuth=30, elevation=30)


def create_grid3D(min, max, steps):
    if type(min) is int:
        min = (min, min, min)
    if type(max) is int:
        max = (max, max, max)
    if type(steps) is int:
        steps = (steps, steps, steps)
    arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
    arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
    arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
    gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX], indexing='ij')
    coords = torch.stack([gridW, girdH, gridD])    # [2, steps[0], steps[1], steps[2]]
    coords = coords.view(3, -1).t()    # [N, 3]
    return coords


def create_grid2D(min, max, steps):
    if type(min) is int:
        min = (min, min)
    if type(max) is int:
        max = (max, max)
    if type(steps) is int:
        steps = (steps, steps)
    arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
    arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
    girdH, gridW = torch.meshgrid([arrangeY, arrangeX], indexing='ij')
    coords = torch.stack([gridW, girdH])
    coords = coords.view(2, -1).t()
    return coords


class SmoothConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        weight = torch.ones(
            (in_channels, out_channels, kernel_size, kernel_size),
            ...
        self.register_buffer('weight', weight)

    def forward(self, input):
        ...


class SmoothConv3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        weight = torch.ones(
            (in_channels, out_channels, kernel_size, kernel_size, kernel_size),
(in_channels, out_channels, kernel_size, kernel_size, kernel_size),
|
179 |
-
|
180 |
self.register_buffer('weight', weight)
|
181 |
|
182 |
def forward(self, input):
|
183 |
return F.conv3d(input, self.weight, padding=self.padding)
|
184 |
|
185 |
|
186 |
-
def build_smooth_conv3D(in_channels=1,
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
padding=padding)
|
194 |
smooth_conv.weight.data = torch.ones(
|
195 |
-
(in_channels, out_channels, kernel_size, kernel_size, kernel_size),
|
196 |
-
|
197 |
smooth_conv.bias.data = torch.zeros(out_channels)
|
198 |
return smooth_conv
|
199 |
|
200 |
|
201 |
-
def build_smooth_conv2D(in_channels=1,
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
padding=padding)
|
209 |
smooth_conv.weight.data = torch.ones(
|
210 |
-
(in_channels, out_channels, kernel_size, kernel_size),
|
211 |
-
|
212 |
smooth_conv.bias.data = torch.zeros(out_channels)
|
213 |
return smooth_conv
|
214 |
|
215 |
|
216 |
-
def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
|
217 |
-
**kwargs):
|
218 |
"""
|
219 |
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
220 |
Args:
|
@@ -233,28 +211,21 @@ def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
|
|
233 |
# d_step = 1.0 / float(D)
|
234 |
|
235 |
num_points = min(D * H * W, num_points)
|
236 |
-
point_scores, point_indices = torch.topk(
|
237 |
-
R, D * H * W),
|
238 |
-
|
239 |
-
|
240 |
-
point_coords = torch.zeros(R,
|
241 |
-
num_points,
|
242 |
-
3,
|
243 |
-
dtype=torch.float,
|
244 |
-
device=uncertainty_map.device)
|
245 |
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
|
246 |
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
|
247 |
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
|
248 |
-
point_coords[:, :, 0] = (point_indices % W).to(torch.float)
|
249 |
-
point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)
|
250 |
-
point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)
|
251 |
-
print(f"resolution {D} x {H} x {W}", point_scores.min(),
|
252 |
-
point_scores.max())
|
253 |
return point_indices, point_coords
|
254 |
|
255 |
|
256 |
-
def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
|
257 |
-
clip_min):
|
258 |
"""
|
259 |
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
260 |
Args:
|
@@ -276,28 +247,21 @@ def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
|
|
276 |
uncertainty_map = uncertainty_map.view(D * H * W)
|
277 |
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
|
278 |
num_points = min(num_points, indices.size(0))
|
279 |
-
point_scores, point_indices = torch.topk(uncertainty_map[indices],
|
280 |
-
k=num_points,
|
281 |
-
dim=0)
|
282 |
point_indices = indices[point_indices].unsqueeze(0)
|
283 |
|
284 |
-
point_coords = torch.zeros(R,
|
285 |
-
num_points,
|
286 |
-
3,
|
287 |
-
dtype=torch.float,
|
288 |
-
device=uncertainty_map.device)
|
289 |
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
|
290 |
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
|
291 |
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
|
292 |
-
point_coords[:, :, 0] = (point_indices % W).to(torch.float)
|
293 |
-
point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)
|
294 |
-
point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)
|
295 |
# print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
|
296 |
return point_indices, point_coords
|
297 |
|
298 |
|
299 |
-
def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
|
300 |
-
**kwargs):
|
301 |
"""
|
302 |
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
303 |
Args:
|
@@ -315,14 +279,8 @@ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
|
|
315 |
# w_step = 1.0 / float(W)
|
316 |
|
317 |
num_points = min(H * W, num_points)
|
318 |
-
point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
|
319 |
-
|
320 |
-
dim=1)
|
321 |
-
point_coords = torch.zeros(R,
|
322 |
-
num_points,
|
323 |
-
2,
|
324 |
-
dtype=torch.long,
|
325 |
-
device=uncertainty_map.device)
|
326 |
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
327 |
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
328 |
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
|
@@ -331,8 +289,7 @@ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
|
|
331 |
return point_indices, point_coords
|
332 |
|
333 |
|
334 |
-
def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
|
335 |
-
clip_min):
|
336 |
"""
|
337 |
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
338 |
Args:
|
@@ -353,16 +310,10 @@ def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
|
|
353 |
uncertainty_map = uncertainty_map.view(H * W)
|
354 |
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
|
355 |
num_points = min(num_points, indices.size(0))
|
356 |
-
point_scores, point_indices = torch.topk(uncertainty_map[indices],
|
357 |
-
k=num_points,
|
358 |
-
dim=0)
|
359 |
point_indices = indices[point_indices].unsqueeze(0)
|
360 |
|
361 |
-
point_coords = torch.zeros(R,
|
362 |
-
num_points,
|
363 |
-
2,
|
364 |
-
dtype=torch.long,
|
365 |
-
device=uncertainty_map.device)
|
366 |
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
367 |
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
368 |
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
|
@@ -388,7 +339,6 @@ def calculate_uncertainty(logits, classes=None, balance_value=0.5):
|
|
388 |
if logits.shape[1] == 1:
|
389 |
gt_class_logits = logits
|
390 |
else:
|
391 |
-
gt_class_logits = logits[
|
392 |
-
|
393 |
-
classes].unsqueeze(1)
|
394 |
return -torch.abs(gt_class_logits - balance_value)
|
|
|
20 |
import matplotlib.pyplot as plt
|
21 |
|
22 |
|
23 |
+
def plot_mask2D(mask, title="", point_coords=None, figsize=10, point_marker_size=5):
|
|
|
|
|
|
|
|
|
24 |
'''
|
25 |
Simple plotting tool to show intermediate mask predictions and points
|
26 |
where PointRend is applied.
|
|
|
42 |
plt.xlabel(W, fontsize=30)
|
43 |
plt.xticks([], [])
|
44 |
plt.yticks([], [])
|
45 |
+
plt.imshow(mask.detach(), interpolation="nearest", cmap=plt.get_cmap('gray'))
|
|
|
|
|
46 |
if point_coords is not None:
|
47 |
+
plt.scatter(
|
48 |
+
x=point_coords[0], y=point_coords[1], color="red", s=point_marker_size, clip_on=True
|
49 |
+
)
|
|
|
|
|
50 |
plt.xlim(-0.5, W - 0.5)
|
51 |
plt.ylim(H - 0.5, -0.5)
|
52 |
plt.show()
|
53 |
|
54 |
|
55 |
+
def plot_mask3D(
|
56 |
+
mask=None, title="", point_coords=None, figsize=1500, point_marker_size=8, interactive=True
|
57 |
+
):
|
|
|
|
|
|
|
58 |
'''
|
59 |
Simple plotting tool to show intermediate mask predictions and points
|
60 |
where PointRend is applied.
|
|
|
79 |
|
80 |
# marching cube to find surface
|
81 |
verts, faces, normals, values = measure.marching_cubes_lewiner(
|
82 |
+
mask, 0.5, gradient_direction='ascent'
|
83 |
+
)
|
84 |
|
85 |
# create a mesh
|
86 |
mesh = trimesh.Trimesh(verts, faces)
|
|
|
100 |
pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
|
101 |
vis_list.append(pc)
|
102 |
|
103 |
+
vp.show(*vis_list, bg="white", axes=1, interactive=interactive, azimuth=30, elevation=30)
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
|
106 |
def create_grid3D(min, max, steps):
|
107 |
if type(min) is int:
|
108 |
+
min = (min, min, min) # (x, y, z)
|
109 |
if type(max) is int:
|
110 |
+
max = (max, max, max) # (x, y)
|
111 |
if type(steps) is int:
|
112 |
+
steps = (steps, steps, steps) # (x, y, z)
|
113 |
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
|
114 |
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
|
115 |
arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
|
116 |
+
gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX], indexing='ij')
|
117 |
+
coords = torch.stack([gridW, girdH, gridD]) # [2, steps[0], steps[1], steps[2]]
|
118 |
+
coords = coords.view(3, -1).t() # [N, 3]
|
|
|
|
|
119 |
return coords
|
120 |
|
121 |
|
122 |
def create_grid2D(min, max, steps):
|
123 |
if type(min) is int:
|
124 |
+
min = (min, min) # (x, y)
|
125 |
if type(max) is int:
|
126 |
+
max = (max, max) # (x, y)
|
127 |
if type(steps) is int:
|
128 |
+
steps = (steps, steps) # (x, y)
|
129 |
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
|
130 |
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
|
131 |
girdH, gridW = torch.meshgrid([arrangeY, arrangeX], indexing='ij')
|
132 |
+
coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
|
133 |
+
coords = coords.view(2, -1).t() # [N, 2]
|
134 |
return coords
|
135 |
|
136 |
|
137 |
class SmoothConv2D(nn.Module):
|
|
|
138 |
def __init__(self, in_channels, out_channels, kernel_size=3):
|
139 |
super().__init__()
|
140 |
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
|
141 |
self.padding = (kernel_size - 1) // 2
|
142 |
|
143 |
weight = torch.ones(
|
144 |
+
(in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32
|
145 |
+
) / (kernel_size**2)
|
146 |
self.register_buffer('weight', weight)
|
147 |
|
148 |
def forward(self, input):
|
|
|
150 |
|
151 |
|
152 |
class SmoothConv3D(nn.Module):
|
|
|
153 |
def __init__(self, in_channels, out_channels, kernel_size=3):
|
154 |
super().__init__()
|
155 |
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
|
156 |
self.padding = (kernel_size - 1) // 2
|
157 |
|
158 |
weight = torch.ones(
|
159 |
+
(in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32
|
160 |
+
) / (kernel_size**3)
|
161 |
self.register_buffer('weight', weight)
|
162 |
|
163 |
def forward(self, input):
|
164 |
return F.conv3d(input, self.weight, padding=self.padding)
|
165 |
|
166 |
|
167 |
+
def build_smooth_conv3D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
|
168 |
+
smooth_conv = torch.nn.Conv3d(
|
169 |
+
in_channels=in_channels,
|
170 |
+
out_channels=out_channels,
|
171 |
+
kernel_size=kernel_size,
|
172 |
+
padding=padding
|
173 |
+
)
|
|
|
174 |
smooth_conv.weight.data = torch.ones(
|
175 |
+
(in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32
|
176 |
+
) / (kernel_size**3)
|
177 |
smooth_conv.bias.data = torch.zeros(out_channels)
|
178 |
return smooth_conv
|
179 |
|
180 |
|
181 |
+
def build_smooth_conv2D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
|
182 |
+
smooth_conv = torch.nn.Conv2d(
|
183 |
+
in_channels=in_channels,
|
184 |
+
out_channels=out_channels,
|
185 |
+
kernel_size=kernel_size,
|
186 |
+
padding=padding
|
187 |
+
)
|
|
|
188 |
smooth_conv.weight.data = torch.ones(
|
189 |
+
(in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32
|
190 |
+
) / (kernel_size**2)
|
191 |
smooth_conv.bias.data = torch.zeros(out_channels)
|
192 |
return smooth_conv
|
193 |
|
194 |
|
195 |
+
def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, **kwargs):
|
|
|
196 |
"""
|
197 |
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
198 |
Args:
|
|
|
211 |
# d_step = 1.0 / float(D)
|
212 |
|
213 |
num_points = min(D * H * W, num_points)
|
214 |
+
point_scores, point_indices = torch.topk(
|
215 |
+
uncertainty_map.view(R, D * H * W), k=num_points, dim=1
|
216 |
+
)
|
217 |
+
point_coords = torch.zeros(R, num_points, 3, dtype=torch.float, device=uncertainty_map.device)
|
|
|
|
|
|
|
|
|
|
|
218 |
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
|
219 |
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
|
220 |
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
|
221 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
|
222 |
+
point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
|
223 |
+
point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
|
224 |
+
print(f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
|
|
|
225 |
return point_indices, point_coords
|
226 |
|
227 |
|
228 |
+
def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, clip_min):
|
|
|
229 |
"""
|
230 |
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
231 |
Args:
|
|
|
247 |
uncertainty_map = uncertainty_map.view(D * H * W)
|
248 |
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
|
249 |
num_points = min(num_points, indices.size(0))
|
250 |
+
point_scores, point_indices = torch.topk(uncertainty_map[indices], k=num_points, dim=0)
|
|
|
|
|
251 |
point_indices = indices[point_indices].unsqueeze(0)
|
252 |
|
253 |
+
point_coords = torch.zeros(R, num_points, 3, dtype=torch.float, device=uncertainty_map.device)
|
|
|
|
|
|
|
|
|
254 |
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
|
255 |
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
|
256 |
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
|
257 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
|
258 |
+
point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
|
259 |
+
point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
|
260 |
# print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
|
261 |
return point_indices, point_coords
|
262 |
|
263 |
|
264 |
+
def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, **kwargs):
|
|
|
265 |
"""
|
266 |
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
267 |
Args:
|
|
|
279 |
# w_step = 1.0 / float(W)
|
280 |
|
281 |
num_points = min(H * W, num_points)
|
282 |
+
point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)
|
283 |
+
point_coords = torch.zeros(R, num_points, 2, dtype=torch.long, device=uncertainty_map.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
285 |
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
286 |
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
|
|
|
289 |
return point_indices, point_coords
|
290 |
|
291 |
|
292 |
+
def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, clip_min):
|
|
|
293 |
"""
|
294 |
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
295 |
Args:
|
|
|
310 |
uncertainty_map = uncertainty_map.view(H * W)
|
311 |
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
|
312 |
num_points = min(num_points, indices.size(0))
|
313 |
+
point_scores, point_indices = torch.topk(uncertainty_map[indices], k=num_points, dim=0)
|
|
|
|
|
314 |
point_indices = indices[point_indices].unsqueeze(0)
|
315 |
|
316 |
+
point_coords = torch.zeros(R, num_points, 2, dtype=torch.long, device=uncertainty_map.device)
|
|
|
|
|
|
|
|
|
317 |
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
318 |
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
319 |
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
|
|
|
339 |
if logits.shape[1] == 1:
|
340 |
gt_class_logits = logits
|
341 |
else:
|
342 |
+
gt_class_logits = logits[torch.arange(logits.shape[0], device=logits.device),
|
343 |
+
classes].unsqueeze(1)
|
|
|
344 |
return -torch.abs(gt_class_logits - balance_value)
|
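The `get_uncertain_point_coords_on_grid3D*` helpers above take the top-k most uncertain voxels from a flattened volume and decode the flat indices back into (x, y, z) grid coordinates. A small sketch of that index arithmetic on a random toy map (purely illustrative):

```python
import torch

D, H, W, num_points = 4, 4, 4, 5
uncertainty_map = torch.rand(1, 1, D, H, W)

# Top-k over the flattened D*H*W grid ...
scores, flat_idx = torch.topk(uncertainty_map.view(1, D * H * W), k=num_points, dim=1)

# ... then decode flat indices back to grid coordinates.
coords = torch.zeros(1, num_points, 3)
coords[:, :, 0] = (flat_idx % W).float()               # x
coords[:, :, 1] = (flat_idx % (H * W) // W).float()    # y
coords[:, :, 2] = (flat_idx // (H * W)).float()        # z
print(coords.shape)    # torch.Size([1, 5, 3])
```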
lib/common/train_util.py
CHANGED
@@ -14,63 +14,62 @@
|
|
14 |
#
|
15 |
# Contact: [email protected]
|
16 |
|
17 |
-
import yaml
|
18 |
-
import os.path as osp
|
19 |
import torch
|
20 |
-
import numpy as np
|
21 |
from ..dataset.mesh_util import *
|
22 |
from ..net.geometry import orthogonal
|
23 |
-
import cv2, PIL
|
24 |
-
from tqdm import tqdm
|
25 |
-
import os
|
26 |
from termcolor import colored
|
27 |
import pytorch_lightning as pl
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
30 |
def init_loss():
|
31 |
|
32 |
losses = {
|
33 |
-
|
34 |
"cloth": {
|
35 |
"weight": 1e3,
|
36 |
"value": 0.0
|
37 |
},
|
38 |
-
|
39 |
-
"
|
40 |
"weight": 1e5,
|
41 |
"value": 0.0
|
42 |
},
|
43 |
-
|
44 |
"rigid": {
|
45 |
"weight": 1e5,
|
46 |
"value": 0.0
|
47 |
},
|
48 |
-
|
49 |
"edge": {
|
50 |
"weight": 0,
|
51 |
"value": 0.0
|
52 |
},
|
53 |
-
|
54 |
"nc": {
|
55 |
"weight": 0,
|
56 |
"value": 0.0
|
57 |
},
|
58 |
-
|
59 |
-
"
|
60 |
"weight": 1e2,
|
61 |
"value": 0.0
|
62 |
},
|
63 |
-
|
64 |
"normal": {
|
65 |
"weight": 1e0,
|
66 |
"value": 0.0
|
67 |
},
|
68 |
-
|
69 |
"silhouette": {
|
70 |
"weight": 1e0,
|
71 |
"value": 0.0
|
72 |
},
|
73 |
-
|
74 |
"joint": {
|
75 |
"weight": 5e0,
|
76 |
"value": 0.0
|
@@ -81,7 +80,6 @@ def init_loss():
|
|
81 |
|
82 |
|
83 |
class SubTrainer(pl.Trainer):
|
84 |
-
|
85 |
def save_checkpoint(self, filepath, weights_only=False):
|
86 |
"""Save model/training states as a checkpoint file through state-dump and file-write.
|
87 |
Args:
|
@@ -101,214 +99,6 @@ class SubTrainer(pl.Trainer):
|
|
101 |
pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
|
102 |
|
103 |
|
104 |
-
def rename(old_dict, old_name, new_name):
|
105 |
-
new_dict = {}
|
106 |
-
for key, value in zip(old_dict.keys(), old_dict.values()):
|
107 |
-
new_key = key if key != old_name else new_name
|
108 |
-
new_dict[new_key] = old_dict[key]
|
109 |
-
return new_dict
|
110 |
-
|
111 |
-
|
112 |
-
def load_normal_networks(model, normal_path):
|
113 |
-
|
114 |
-
pretrained_dict = torch.load(
|
115 |
-
normal_path,
|
116 |
-
map_location=model.device)["state_dict"]
|
117 |
-
model_dict = model.state_dict()
|
118 |
-
|
119 |
-
# 1. filter out unnecessary keys
|
120 |
-
pretrained_dict = {
|
121 |
-
k: v
|
122 |
-
for k, v in pretrained_dict.items()
|
123 |
-
if k in model_dict and v.shape == model_dict[k].shape
|
124 |
-
}
|
125 |
-
|
126 |
-
# # 2. overwrite entries in the existing state dict
|
127 |
-
model_dict.update(pretrained_dict)
|
128 |
-
# 3. load the new state dict
|
129 |
-
model.load_state_dict(model_dict)
|
130 |
-
|
131 |
-
del pretrained_dict
|
132 |
-
del model_dict
|
133 |
-
|
134 |
-
print(colored(f"Resume Normal weights from {normal_path}", "green"))
|
135 |
-
|
136 |
-
|
137 |
-
def load_networks(model, mlp_path, normal_path=None):
|
138 |
-
|
139 |
-
model_dict = model.state_dict()
|
140 |
-
main_dict = {}
|
141 |
-
normal_dict = {}
|
142 |
-
|
143 |
-
# MLP part loading
|
144 |
-
if os.path.exists(mlp_path) and mlp_path.endswith("ckpt"):
|
145 |
-
main_dict = torch.load(
|
146 |
-
mlp_path,
|
147 |
-
map_location=model.device)["state_dict"]
|
148 |
-
|
149 |
-
main_dict = {
|
150 |
-
k: v
|
151 |
-
for k, v in main_dict.items()
|
152 |
-
if k in model_dict and v.shape == model_dict[k].shape and (
|
153 |
-
"reconEngine" not in k) and ("normal_filter" not in k) and (
|
154 |
-
"voxelization" not in k)
|
155 |
-
}
|
156 |
-
print(colored(f"Resume MLP weights from {mlp_path}", "green"))
|
157 |
-
|
158 |
-
# normal network part loading
|
159 |
-
if normal_path is not None and os.path.exists(normal_path) and normal_path.endswith("ckpt"):
|
160 |
-
normal_dict = torch.load(
|
161 |
-
normal_path,
|
162 |
-
map_location=model.device)["state_dict"]
|
163 |
-
|
164 |
-
for key in normal_dict.keys():
|
165 |
-
normal_dict = rename(normal_dict, key,
|
166 |
-
key.replace("netG", "netG.normal_filter"))
|
167 |
-
|
168 |
-
normal_dict = {
|
169 |
-
k: v
|
170 |
-
for k, v in normal_dict.items()
|
171 |
-
if k in model_dict and v.shape == model_dict[k].shape
|
172 |
-
}
|
173 |
-
print(colored(f"Resume normal model from {normal_path}", "green"))
|
174 |
-
|
175 |
-
model_dict.update(main_dict)
|
176 |
-
model_dict.update(normal_dict)
|
177 |
-
model.load_state_dict(model_dict)
|
178 |
-
|
179 |
-
# clean unused GPU memory
|
180 |
-
del main_dict
|
181 |
-
del normal_dict
|
182 |
-
del model_dict
|
183 |
-
torch.cuda.empty_cache()
|
184 |
-
|
185 |
-
|
186 |
-
def reshape_sample_tensor(sample_tensor, num_views):
|
187 |
-
if num_views == 1:
|
188 |
-
return sample_tensor
|
189 |
-
# Need to repeat sample_tensor along the batch dim num_views times
|
190 |
-
sample_tensor = sample_tensor.unsqueeze(dim=1)
|
191 |
-
sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
|
192 |
-
sample_tensor = sample_tensor.view(
|
193 |
-
sample_tensor.shape[0] * sample_tensor.shape[1],
|
194 |
-
sample_tensor.shape[2],
|
195 |
-
sample_tensor.shape[3],
|
196 |
-
)
|
197 |
-
return sample_tensor
|
198 |
-
|
199 |
-
|
200 |
-
def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
|
201 |
-
"""Sets the learning rate to the initial LR decayed by schedule"""
|
202 |
-
if epoch in schedule:
|
203 |
-
lr *= gamma
|
204 |
-
for param_group in optimizer.param_groups:
|
205 |
-
param_group["lr"] = lr
|
206 |
-
return lr
|
207 |
-
|
208 |
-
|
209 |
-
def compute_acc(pred, gt, thresh=0.5):
|
210 |
-
"""
|
211 |
-
return:
|
212 |
-
IOU, precision, and recall
|
213 |
-
"""
|
214 |
-
with torch.no_grad():
|
215 |
-
vol_pred = pred > thresh
|
216 |
-
vol_gt = gt > thresh
|
217 |
-
|
218 |
-
union = vol_pred | vol_gt
|
219 |
-
inter = vol_pred & vol_gt
|
220 |
-
|
221 |
-
true_pos = inter.sum().float()
|
222 |
-
|
223 |
-
union = union.sum().float()
|
224 |
-
if union == 0:
|
225 |
-
union = 1
|
226 |
-
vol_pred = vol_pred.sum().float()
|
227 |
-
if vol_pred == 0:
|
228 |
-
vol_pred = 1
|
229 |
-
vol_gt = vol_gt.sum().float()
|
230 |
-
if vol_gt == 0:
|
231 |
-
vol_gt = 1
|
232 |
-
return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
|
233 |
-
|
234 |
-
def calc_error(opt, net, cuda, dataset, num_tests):
|
235 |
-
if num_tests > len(dataset):
|
236 |
-
num_tests = len(dataset)
|
237 |
-
with torch.no_grad():
|
238 |
-
erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
|
239 |
-
for idx in tqdm(range(num_tests)):
|
240 |
-
data = dataset[idx * len(dataset) // num_tests]
|
241 |
-
# retrieve the data
|
242 |
-
image_tensor = data["img"].to(device=cuda)
|
243 |
-
calib_tensor = data["calib"].to(device=cuda)
|
244 |
-
sample_tensor = data["samples"].to(device=cuda).unsqueeze(0)
|
245 |
-
if opt.num_views > 1:
|
246 |
-
sample_tensor = reshape_sample_tensor(sample_tensor,
|
247 |
-
opt.num_views)
|
248 |
-
label_tensor = data["labels"].to(device=cuda).unsqueeze(0)
|
249 |
-
|
250 |
-
res, error = net.forward(image_tensor,
|
251 |
-
sample_tensor,
|
252 |
-
calib_tensor,
|
253 |
-
labels=label_tensor)
|
254 |
-
|
255 |
-
IOU, prec, recall = compute_acc(res, label_tensor)
|
256 |
-
|
257 |
-
# print(
|
258 |
-
# '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
|
259 |
-
# .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
|
260 |
-
erorr_arr.append(error.item())
|
261 |
-
IOU_arr.append(IOU.item())
|
262 |
-
prec_arr.append(prec.item())
|
263 |
-
recall_arr.append(recall.item())
|
264 |
-
|
265 |
-
return (
|
266 |
-
np.average(erorr_arr),
|
267 |
-
np.average(IOU_arr),
|
268 |
-
np.average(prec_arr),
|
269 |
-
np.average(recall_arr),
|
270 |
-
)
|
271 |
-
|
272 |
-
|
273 |
-
def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
|
274 |
-
if num_tests > len(dataset):
|
275 |
-
num_tests = len(dataset)
|
276 |
-
with torch.no_grad():
|
277 |
-
error_color_arr = []
|
278 |
-
|
279 |
-
for idx in tqdm(range(num_tests)):
|
280 |
-
data = dataset[idx * len(dataset) // num_tests]
|
281 |
-
# retrieve the data
|
282 |
-
image_tensor = data["img"].to(device=cuda)
|
283 |
-
calib_tensor = data["calib"].to(device=cuda)
|
284 |
-
color_sample_tensor = data["color_samples"].to(
|
285 |
-
device=cuda).unsqueeze(0)
|
286 |
-
|
287 |
-
if opt.num_views > 1:
|
288 |
-
color_sample_tensor = reshape_sample_tensor(
|
289 |
-
color_sample_tensor, opt.num_views)
|
290 |
-
|
291 |
-
rgb_tensor = data["rgbs"].to(device=cuda).unsqueeze(0)
|
292 |
-
|
293 |
-
netG.filter(image_tensor)
|
294 |
-
_, errorC = netC.forward(
|
295 |
-
image_tensor,
|
296 |
-
netG.get_im_feat(),
|
297 |
-
color_sample_tensor,
|
298 |
-
calib_tensor,
|
299 |
-
labels=rgb_tensor,
|
300 |
-
)
|
301 |
-
|
302 |
-
# print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
|
303 |
-
# .format(idx, num_tests, errorG.item(), errorC.item()))
|
304 |
-
error_color_arr.append(errorC.item())
|
305 |
-
|
306 |
-
return np.average(error_color_arr)
|
307 |
-
|
308 |
-
|
309 |
-
# pytorch lightning training related fucntions
|
310 |
-
|
311 |
-
|
312 |
def query_func(opt, netG, features, points, proj_matrix=None):
|
313 |
"""
|
314 |
- points: size of (bz, N, 3)
|
@@ -317,7 +107,7 @@ def query_func(opt, netG, features, points, proj_matrix=None):
|
|
317 |
"""
|
318 |
assert len(points) == 1
|
319 |
samples = points.repeat(opt.num_views, 1, 1)
|
320 |
-
samples = samples.permute(0, 2, 1)
|
321 |
|
322 |
# view specific query
|
323 |
if proj_matrix is not None:
|
@@ -337,85 +127,25 @@ def query_func(opt, netG, features, points, proj_matrix=None):
|
|
337 |
|
338 |
return preds
|
339 |
|
|
|
340 |
def query_func_IF(batch, netG, points):
|
341 |
"""
|
342 |
- points: size of (bz, N, 3)
|
343 |
return: size of (bz, 1, N)
|
344 |
"""
|
345 |
-
|
346 |
batch["samples_geo"] = points
|
347 |
batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points)
|
348 |
-
|
349 |
preds = netG(batch)
|
350 |
|
351 |
return preds.unsqueeze(1)
|
352 |
|
353 |
|
354 |
-
def isin(ar1, ar2):
|
355 |
-
return (ar1[..., None] == ar2).any(-1)
|
356 |
-
|
357 |
-
|
358 |
-
def in1d(ar1, ar2):
|
359 |
-
mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
|
360 |
-
mask[ar2.unique()] = True
|
361 |
-
return mask[ar1]
|
362 |
-
|
363 |
def batch_mean(res, key):
|
364 |
-
return torch.stack(
|
365 |
-
x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key])
|
366 |
-
|
367 |
-
]).mean()
|
368 |
-
|
369 |
-
|
370 |
-
def tf_log_convert(log_dict):
|
371 |
-
new_log_dict = log_dict.copy()
|
372 |
-
for k, v in log_dict.items():
|
373 |
-
new_log_dict[k.replace("_", "/")] = v
|
374 |
-
del new_log_dict[k]
|
375 |
-
|
376 |
-
return new_log_dict
|
377 |
-
|
378 |
-
|
379 |
-
def bar_log_convert(log_dict, name=None, rot=None):
|
380 |
-
from decimal import Decimal
|
381 |
-
|
382 |
-
new_log_dict = {}
|
383 |
-
|
384 |
-
if name is not None:
|
385 |
-
new_log_dict["name"] = name[0]
|
386 |
-
if rot is not None:
|
387 |
-
new_log_dict["rot"] = rot[0]
|
388 |
-
|
389 |
-
for k, v in log_dict.items():
|
390 |
-
color = "yellow"
|
391 |
-
if "loss" in k:
|
392 |
-
color = "red"
|
393 |
-
k = k.replace("loss", "L")
|
394 |
-
elif "acc" in k:
|
395 |
-
color = "green"
|
396 |
-
k = k.replace("acc", "A")
|
397 |
-
elif "iou" in k:
|
398 |
-
color = "green"
|
399 |
-
k = k.replace("iou", "I")
|
400 |
-
elif "prec" in k:
|
401 |
-
color = "green"
|
402 |
-
k = k.replace("prec", "P")
|
403 |
-
elif "recall" in k:
|
404 |
-
color = "green"
|
405 |
-
k = k.replace("recall", "R")
|
406 |
-
|
407 |
-
if "lr" not in k:
|
408 |
-
new_log_dict[colored(k.split("_")[1],
|
409 |
-
color)] = colored(f"{v:.3f}", color)
|
410 |
-
else:
|
411 |
-
new_log_dict[colored(k.split("_")[1],
|
412 |
-
color)] = colored(f"{Decimal(str(v)):.1E}",
|
413 |
-
color)
|
414 |
-
|
415 |
-
if "loss" in new_log_dict.keys():
|
416 |
-
del new_log_dict["loss"]
|
417 |
-
|
418 |
-
return new_log_dict
|
419 |
|
420 |
|
421 |
def accumulate(outputs, rot_num, split):
|
@@ -430,160 +160,10 @@ def accumulate(outputs, rot_num, split):
|
|
430 |
keyword = f"{dataset}/{metric}"
|
431 |
if keyword not in hparam_log_dict.keys():
|
432 |
hparam_log_dict[keyword] = 0
|
433 |
-
for idx in range(split[dataset][0] * rot_num,
|
434 |
-
split[dataset][1] * rot_num):
|
435 |
hparam_log_dict[keyword] += outputs[idx][metric].item()
|
436 |
-
hparam_log_dict[keyword] /= (split[dataset][1] -
|
437 |
-
split[dataset][0]) * rot_num
|
438 |
|
439 |
print(colored(hparam_log_dict, "green"))
|
440 |
|
441 |
return hparam_log_dict
|
442 |
-
|
443 |
-
|
444 |
-
def calc_error_N(outputs, targets):
|
445 |
-
"""calculate the error of normal (IGR)
|
446 |
-
|
447 |
-
Args:
|
448 |
-
outputs (torch.tensor): [B, 3, N]
|
449 |
-
target (torch.tensor): [B, N, 3]
|
450 |
-
|
451 |
-
# manifold loss and grad_loss in IGR paper
|
452 |
-
grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
|
453 |
-
normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
|
454 |
-
|
455 |
-
Returns:
|
456 |
-
torch.tensor: error of valid normals on the surface
|
457 |
-
"""
|
458 |
-
# outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
|
459 |
-
outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
|
460 |
-
targets = targets.reshape(-1, 3)[:, 2:3]
|
461 |
-
with_normals = targets.sum(dim=1).abs() > 0.0
|
462 |
-
|
463 |
-
# eikonal loss
|
464 |
-
grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
|
465 |
-
# normals loss
|
466 |
-
normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()
|
467 |
-
|
468 |
-
return grad_loss * 0.0 + normal_loss
|
469 |
-
|
470 |
-
|
471 |
-
def calc_knn_acc(preds, carn_verts, labels, pick_num):
|
472 |
-
"""calculate knn accuracy
|
473 |
-
|
474 |
-
Args:
|
475 |
-
preds (torch.tensor): [B, 3, N]
|
476 |
-
carn_verts (torch.tensor): [SMPLX_V_num, 3]
|
477 |
-
labels (torch.tensor): [B, N_knn, N]
|
478 |
-
"""
|
479 |
-
N_knn_full = labels.shape[1]
|
480 |
-
preds = preds.permute(0, 2, 1).reshape(-1, 3)
|
481 |
-
labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full) # [BxN, num_knn]
|
482 |
-
labels = labels[:, :pick_num]
|
483 |
-
|
484 |
-
dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num]
|
485 |
-
knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn]
|
486 |
-
cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
|
487 |
-
bool_col = torch.zeros_like(cat_mat)[:, 0]
|
488 |
-
for i in range(pick_num * 2 - 1):
|
489 |
-
bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
|
490 |
-
acc = (bool_col > 0).sum() / len(bool_col)
|
491 |
-
|
492 |
-
return acc
|
493 |
-
|
494 |
-
|
495 |
-
def calc_acc_seg(output, target, num_multiseg):
|
496 |
-
from pytorch_lightning.metrics import Accuracy
|
497 |
-
|
498 |
-
return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
|
499 |
-
target.flatten().cpu())
|
500 |
-
|
501 |
-
|
502 |
-
def add_watermark(imgs, titles):
|
503 |
-
|
504 |
-
# Write some Text
|
505 |
-
|
506 |
-
font = cv2.FONT_HERSHEY_SIMPLEX
|
507 |
-
bottomLeftCornerOfText = (350, 50)
|
508 |
-
bottomRightCornerOfText = (800, 50)
|
509 |
-
fontScale = 1
|
510 |
-
fontColor = (1.0, 1.0, 1.0)
|
511 |
-
lineType = 2
|
512 |
-
|
513 |
-
for i in range(len(imgs)):
|
514 |
-
|
515 |
-
title = titles[i + 1]
|
516 |
-
cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
|
517 |
-
fontColor, lineType)
|
518 |
-
|
519 |
-
if i == 0:
|
520 |
-
cv2.putText(
|
521 |
-
imgs[i],
|
522 |
-
str(titles[i][0]),
|
523 |
-
bottomRightCornerOfText,
|
524 |
-
font,
|
525 |
-
fontScale,
|
526 |
-
fontColor,
|
527 |
-
lineType,
|
528 |
-
)
|
529 |
-
|
530 |
-
result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)
|
531 |
-
|
532 |
-
return result
|
533 |
-
|
534 |
-
|
535 |
-
def make_test_gif(img_dir):
|
536 |
-
|
537 |
-
if img_dir is not None and len(os.listdir(img_dir)) > 0:
|
538 |
-
for dataset in os.listdir(img_dir):
|
539 |
-
for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
|
540 |
-
img_lst = []
|
541 |
-
im1 = None
|
542 |
-
for file in sorted(
|
543 |
-
os.listdir(osp.join(img_dir, dataset, subject))):
|
544 |
-
if file[-3:] not in ["obj", "gif"]:
|
545 |
-
img_path = os.path.join(img_dir, dataset, subject,
|
546 |
-
file)
|
547 |
-
if im1 == None:
|
548 |
-
im1 = PIL.Image.open(img_path)
|
549 |
-
else:
|
550 |
-
img_lst.append(PIL.Image.open(img_path))
|
551 |
-
|
552 |
-
print(os.path.join(img_dir, dataset, subject, "out.gif"))
|
553 |
-
im1.save(
|
554 |
-
os.path.join(img_dir, dataset, subject, "out.gif"),
|
555 |
-
save_all=True,
|
556 |
-
append_images=img_lst,
|
557 |
-
duration=500,
|
558 |
-
loop=0,
|
559 |
-
)
|
560 |
-
|
561 |
-
|
562 |
-
def export_cfg(logger, dir, cfg):
|
563 |
-
|
564 |
-
cfg_export_file = osp.join(dir, f"cfg_{logger.version}.yaml")
|
565 |
-
|
566 |
-
if not osp.exists(cfg_export_file):
|
567 |
-
os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
|
568 |
-
with open(cfg_export_file, "w+") as file:
|
569 |
-
_ = yaml.dump(cfg, file)
|
570 |
-
|
571 |
-
|
572 |
-
from yacs.config import CfgNode
|
573 |
-
|
574 |
-
_VALID_TYPES = {tuple, list, str, int, float, bool}
|
575 |
-
|
576 |
-
|
577 |
-
def convert_to_dict(cfg_node, key_list=[]):
|
578 |
-
""" Convert a config node to dictionary """
|
579 |
-
if not isinstance(cfg_node, CfgNode):
|
580 |
-
if type(cfg_node) not in _VALID_TYPES:
|
581 |
-
print(
|
582 |
-
"Key {} with value {} is not a valid type; valid types: {}".
|
583 |
-
format(".".join(key_list), type(cfg_node), _VALID_TYPES), )
|
584 |
-
return cfg_node
|
585 |
-
else:
|
586 |
-
cfg_dict = dict(cfg_node)
|
587 |
-
for k, v in cfg_dict.items():
|
588 |
-
cfg_dict[k] = convert_to_dict(v, key_list + [k])
|
589 |
-
return cfg_dict
|
|
|
14 |
#
|
15 |
# Contact: [email protected]
|
16 |
|
|
|
|
|
17 |
import torch
|
|
|
18 |
from ..dataset.mesh_util import *
|
19 |
from ..net.geometry import orthogonal
|
|
|
|
|
|
|
20 |
from termcolor import colored
|
21 |
import pytorch_lightning as pl
|
22 |
|
23 |
|
24 |
+
class Format:
|
25 |
+
end = '\033[0m'
|
26 |
+
start = '\033[4m'
|
27 |
+
|
28 |
+
|
29 |
def init_loss():
|
30 |
|
31 |
losses = {
|
32 |
+
# Cloth: chamfer distance
|
33 |
"cloth": {
|
34 |
"weight": 1e3,
|
35 |
"value": 0.0
|
36 |
},
|
37 |
+
# Stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2)
|
38 |
+
"stiff": {
|
39 |
"weight": 1e5,
|
40 |
"value": 0.0
|
41 |
},
|
42 |
+
# Cloth: det(R) = 1
|
43 |
"rigid": {
|
44 |
"weight": 1e5,
|
45 |
"value": 0.0
|
46 |
},
|
47 |
+
# Cloth: edge length
|
48 |
"edge": {
|
49 |
"weight": 0,
|
50 |
"value": 0.0
|
51 |
},
|
52 |
+
# Cloth: normal consistency
|
53 |
"nc": {
|
54 |
"weight": 0,
|
55 |
"value": 0.0
|
56 |
},
|
57 |
+
# Cloth: laplacian smoothing
|
58 |
+
"lapla": {
|
59 |
"weight": 1e2,
|
60 |
"value": 0.0
|
61 |
},
|
62 |
+
# Body: Normal_pred - Normal_smpl
|
63 |
"normal": {
|
64 |
"weight": 1e0,
|
65 |
"value": 0.0
|
66 |
},
|
67 |
+
# Body: Silhouette_pred - Silhouette_smpl
|
68 |
"silhouette": {
|
69 |
"weight": 1e0,
|
70 |
"value": 0.0
|
71 |
},
|
72 |
+
# Joint: reprojected joints difference
|
73 |
"joint": {
|
74 |
"weight": 5e0,
|
75 |
"value": 0.0
|
|
|
80 |
|
81 |
|
82 |
class SubTrainer(pl.Trainer):
|
|
|
83 |
def save_checkpoint(self, filepath, weights_only=False):
|
84 |
"""Save model/training states as a checkpoint file through state-dump and file-write.
|
85 |
Args:
|
|
|
99 |
pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
|
100 |
|
101 |
102 |
def query_func(opt, netG, features, points, proj_matrix=None):
|
103 |
"""
|
104 |
- points: size of (bz, N, 3)
|
|
|
107 |
"""
|
108 |
assert len(points) == 1
|
109 |
samples = points.repeat(opt.num_views, 1, 1)
|
110 |
+
samples = samples.permute(0, 2, 1) # [bz, 3, N]
|
111 |
|
112 |
# view specific query
|
113 |
if proj_matrix is not None:
|
|
|
127 |
|
128 |
return preds
|
129 |
|
130 |
+
|
131 |
def query_func_IF(batch, netG, points):
|
132 |
"""
|
133 |
- points: size of (bz, N, 3)
|
134 |
return: size of (bz, 1, N)
|
135 |
"""
|
136 |
+
|
137 |
batch["samples_geo"] = points
|
138 |
batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points)
|
139 |
+
|
140 |
preds = netG(batch)
|
141 |
|
142 |
return preds.unsqueeze(1)
|
143 |
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
def batch_mean(res, key):
|
146 |
+
return torch.stack(
|
147 |
+
[x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res]
|
148 |
+
).mean()
149 |
|
150 |
|
151 |
def accumulate(outputs, rot_num, split):
|
|
|
160 |
keyword = f"{dataset}/{metric}"
|
161 |
if keyword not in hparam_log_dict.keys():
|
162 |
hparam_log_dict[keyword] = 0
|
163 |
+
for idx in range(split[dataset][0] * rot_num, split[dataset][1] * rot_num):
|
|
|
164 |
hparam_log_dict[keyword] += outputs[idx][metric].item()
|
165 |
+
hparam_log_dict[keyword] /= (split[dataset][1] - split[dataset][0]) * rot_num
|
|
|
166 |
|
167 |
print(colored(hparam_log_dict, "green"))
|
168 |
|
169 |
return hparam_log_dict
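
`init_loss` only registers per-term weights and running values; the fitting loop later folds them into a single scalar objective. A hedged sketch of how such a weight/value dictionary is typically consumed (the term names follow the dict above, the loss values are placeholders):

```python
import torch

losses = {
    "cloth": {"weight": 1e3, "value": 0.0},     # chamfer distance
    "stiff": {"weight": 1e5, "value": 0.0},     # local rigidity of edges
    "normal": {"weight": 1e0, "value": 0.0},    # predicted vs SMPL normals
}

# Pretend these came out of the current optimization step.
losses["cloth"]["value"] = torch.tensor(0.021)
losses["stiff"]["value"] = torch.tensor(3.4e-6)
losses["normal"]["value"] = torch.tensor(0.12)

# Weighted sum over all active terms.
total_loss = sum(term["weight"] * term["value"] for term in losses.values())
print(float(total_loss))
```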
lib/common/voxelize.py
CHANGED
@@ -13,6 +13,7 @@ from lib.common.libmesh.inside_mesh import check_mesh_contains
|
|
13 |
|
14 |
# From Occupancy Networks, Mescheder et. al. CVPR'19
|
15 |
|
|
|
16 |
def make_3d_grid(bb_min, bb_max, shape):
|
17 |
''' Makes a 3D grid.
|
18 |
|
@@ -37,7 +38,7 @@ def make_3d_grid(bb_min, bb_max, shape):
|
|
37 |
|
38 |
class VoxelGrid:
|
39 |
def __init__(self, data, loc=(0., 0., 0.), scale=1):
|
40 |
-
assert(data.shape[0] == data.shape[1] == data.shape[2])
|
41 |
data = np.asarray(data, dtype=np.bool)
|
42 |
loc = np.asarray(loc)
|
43 |
self.data = data
|
@@ -53,7 +54,7 @@ class VoxelGrid:
|
|
53 |
|
54 |
# Default scale, scales the mesh to [-0.45, 0.45]^3
|
55 |
if scale is None:
|
56 |
-
scale = (bounds[1] - bounds[0]).max()/0.9
|
57 |
|
58 |
loc = np.asarray(loc)
|
59 |
scale = float(scale)
|
@@ -61,7 +62,7 @@ class VoxelGrid:
|
|
61 |
# Transform mesh
|
62 |
mesh = mesh.copy()
|
63 |
mesh.apply_translation(-loc)
|
64 |
-
mesh.apply_scale(1/scale)
|
65 |
|
66 |
# Apply method
|
67 |
if method == 'ray':
|
@@ -75,7 +76,7 @@ class VoxelGrid:
|
|
75 |
def down_sample(self, factor=2):
|
76 |
if not (self.resolution % factor) == 0:
|
77 |
raise ValueError('Resolution must be divisible by factor.')
|
78 |
-
new_data = block_reduce(self.data, (factor,) * 3, np.max)
|
79 |
return VoxelGrid(new_data, self.loc, self.scale)
|
80 |
|
81 |
def to_mesh(self):
|
@@ -103,9 +104,9 @@ class VoxelGrid:
|
|
103 |
f2 = f2_r | f2_l
|
104 |
f3 = f3_r | f3_l
|
105 |
|
106 |
-
assert(f1.shape == (nx + 1, ny, nz))
|
107 |
-
assert(f2.shape == (nx, ny + 1, nz))
|
108 |
-
assert(f3.shape == (nx, ny, nz + 1))
|
109 |
|
110 |
# Determine if vertex present
|
111 |
v = np.full(grid_shape, False)
|
@@ -146,53 +147,76 @@ class VoxelGrid:
|
|
146 |
f2_r_x, f2_r_y, f2_r_z = np.where(f2_r)
|
147 |
f3_r_x, f3_r_y, f3_r_z = np.where(f3_r)
|
148 |
|
149 |
-
faces_1_l = np.stack(
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
vertices = self.loc + self.scale * vertices
|
198 |
mesh = trimesh.Trimesh(vertices, faces, process=False)
|
@@ -200,7 +224,7 @@ class VoxelGrid:
|
|
200 |
|
201 |
@property
|
202 |
def resolution(self):
|
203 |
-
assert(self.data.shape[0] == self.data.shape[1] == self.data.shape[2])
|
204 |
return self.data.shape[0]
|
205 |
|
206 |
def contains(self, points):
|
@@ -211,12 +235,9 @@ class VoxelGrid:
|
|
211 |
# Discretize points to [0, nx-1]^3
|
212 |
points_i = ((points + 0.5) * nx).astype(np.int32)
|
213 |
# i1, i2, i3 have sizes (batch_size, T)
|
214 |
-
i1, i2, i3 = points_i[..., 0],
|
215 |
# Only use indices inside bounding box
|
216 |
-
mask = (
|
217 |
-
(i1 >= 0) & (i2 >= 0) & (i3 >= 0)
|
218 |
-
& (nx > i1) & (nx > i2) & (nx > i3)
|
219 |
-
)
|
220 |
# Prevent out of bounds error
|
221 |
i1 = i1[mask]
|
222 |
i2 = i2[mask]
|
@@ -254,7 +275,7 @@ def voxelize_surface(mesh, resolution):
|
|
254 |
vertices = (vertices + 0.5) * resolution
|
255 |
|
256 |
face_loc = vertices[faces]
|
257 |
-
occ = np.full((resolution,) * 3, 0, dtype=np.int32)
|
258 |
face_loc = face_loc.astype(np.float32)
|
259 |
|
260 |
voxelize_mesh_(occ, face_loc)
|
@@ -264,9 +285,9 @@ def voxelize_surface(mesh, resolution):
|
|
264 |
|
265 |
|
266 |
def voxelize_interior(mesh, resolution):
|
267 |
-
shape = (resolution,) * 3
|
268 |
-
bb_min = (0.5,) * 3
|
269 |
-
bb_max = (resolution - 0.5,) * 3
|
270 |
# Create points. Add noise to break symmetry
|
271 |
points = make_3d_grid(bb_min, bb_max, shape=shape).numpy()
|
272 |
points = points + 0.1 * (np.random.rand(*points.shape) - 0.5)
|
@@ -280,14 +301,9 @@ def check_voxel_occupied(occupancy_grid):
|
|
280 |
occ = occupancy_grid
|
281 |
|
282 |
occupied = (
|
283 |
-
occ[..., :-1, :-1, :-1]
|
284 |
-
& occ[..., :-1, :-1, 1:]
|
285 |
-
& occ[...,
|
286 |
-
& occ[..., :-1, 1:, 1:]
|
287 |
-
& occ[..., 1:, :-1, :-1]
|
288 |
-
& occ[..., 1:, :-1, 1:]
|
289 |
-
& occ[..., 1:, 1:, :-1]
|
290 |
-
& occ[..., 1:, 1:, 1:]
|
291 |
)
|
292 |
return occupied
|
293 |
|
@@ -296,14 +312,9 @@ def check_voxel_unoccupied(occupancy_grid):
|
|
296 |
occ = occupancy_grid
|
297 |
|
298 |
unoccupied = ~(
|
299 |
-
occ[..., :-1, :-1, :-1]
|
300 |
-
| occ[..., :-1, :-1, 1:]
|
301 |
-
| occ[...,
|
302 |
-
| occ[..., :-1, 1:, 1:]
|
303 |
-
| occ[..., 1:, :-1, :-1]
|
304 |
-
| occ[..., 1:, :-1, 1:]
|
305 |
-
| occ[..., 1:, 1:, :-1]
|
306 |
-
| occ[..., 1:, 1:, 1:]
|
307 |
)
|
308 |
return unoccupied
|
309 |
|
|
|
13 |
|
14 |
# From Occupancy Networks, Mescheder et. al. CVPR'19
|
15 |
|
16 |
+
|
17 |
def make_3d_grid(bb_min, bb_max, shape):
|
18 |
''' Makes a 3D grid.
|
19 |
|
|
|
38 |
|
39 |
class VoxelGrid:
|
40 |
def __init__(self, data, loc=(0., 0., 0.), scale=1):
|
41 |
+
assert (data.shape[0] == data.shape[1] == data.shape[2])
|
42 |
data = np.asarray(data, dtype=np.bool)
|
43 |
loc = np.asarray(loc)
|
44 |
self.data = data
|
|
|
54 |
|
55 |
# Default scale, scales the mesh to [-0.45, 0.45]^3
|
56 |
if scale is None:
|
57 |
+
scale = (bounds[1] - bounds[0]).max() / 0.9
|
58 |
|
59 |
loc = np.asarray(loc)
|
60 |
scale = float(scale)
|
|
|
62 |
# Transform mesh
|
63 |
mesh = mesh.copy()
|
64 |
mesh.apply_translation(-loc)
|
65 |
+
mesh.apply_scale(1 / scale)
|
66 |
|
67 |
# Apply method
|
68 |
if method == 'ray':
|
|
|
76 |
def down_sample(self, factor=2):
|
77 |
if not (self.resolution % factor) == 0:
|
78 |
raise ValueError('Resolution must be divisible by factor.')
|
79 |
+
new_data = block_reduce(self.data, (factor, ) * 3, np.max)
|
80 |
return VoxelGrid(new_data, self.loc, self.scale)
|
81 |
|
82 |
def to_mesh(self):
|
|
|
104 |
f2 = f2_r | f2_l
|
105 |
f3 = f3_r | f3_l
|
106 |
|
107 |
+
assert (f1.shape == (nx + 1, ny, nz))
|
108 |
+
assert (f2.shape == (nx, ny + 1, nz))
|
109 |
+
assert (f3.shape == (nx, ny, nz + 1))
|
110 |
|
111 |
# Determine if vertex present
|
112 |
v = np.full(grid_shape, False)
|
|
|
147 |
f2_r_x, f2_r_y, f2_r_z = np.where(f2_r)
|
148 |
f3_r_x, f3_r_y, f3_r_z = np.where(f3_r)
|
149 |
|
150 |
+
faces_1_l = np.stack(
|
151 |
+
[
|
152 |
+
v_idx[f1_l_x, f1_l_y, f1_l_z],
|
153 |
+
v_idx[f1_l_x, f1_l_y, f1_l_z + 1],
|
154 |
+
v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1],
|
155 |
+
v_idx[f1_l_x, f1_l_y + 1, f1_l_z],
|
156 |
+
],
|
157 |
+
axis=1
|
158 |
+
)
|
159 |
+
|
160 |
+
faces_1_r = np.stack(
|
161 |
+
[
|
162 |
+
v_idx[f1_r_x, f1_r_y, f1_r_z],
|
163 |
+
v_idx[f1_r_x, f1_r_y + 1, f1_r_z],
|
164 |
+
v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1],
|
165 |
+
v_idx[f1_r_x, f1_r_y, f1_r_z + 1],
|
166 |
+
],
|
167 |
+
axis=1
|
168 |
+
)
|
169 |
+
|
170 |
+
faces_2_l = np.stack(
|
171 |
+
[
|
172 |
+
v_idx[f2_l_x, f2_l_y, f2_l_z],
|
173 |
+
v_idx[f2_l_x + 1, f2_l_y, f2_l_z],
|
174 |
+
v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1],
|
175 |
+
v_idx[f2_l_x, f2_l_y, f2_l_z + 1],
|
176 |
+
],
|
177 |
+
axis=1
|
178 |
+
)
|
179 |
+
|
180 |
+
faces_2_r = np.stack(
|
181 |
+
[
|
182 |
+
v_idx[f2_r_x, f2_r_y, f2_r_z],
|
183 |
+
v_idx[f2_r_x, f2_r_y, f2_r_z + 1],
|
184 |
+
v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1],
|
185 |
+
v_idx[f2_r_x + 1, f2_r_y, f2_r_z],
|
186 |
+
],
|
187 |
+
axis=1
|
188 |
+
)
|
189 |
+
|
190 |
+
faces_3_l = np.stack(
|
191 |
+
[
|
192 |
+
v_idx[f3_l_x, f3_l_y, f3_l_z],
|
193 |
+
v_idx[f3_l_x, f3_l_y + 1, f3_l_z],
|
194 |
+
v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z],
|
195 |
+
v_idx[f3_l_x + 1, f3_l_y, f3_l_z],
|
196 |
+
],
|
197 |
+
axis=1
|
198 |
+
)
|
199 |
+
|
200 |
+
faces_3_r = np.stack(
|
201 |
+
[
|
202 |
+
v_idx[f3_r_x, f3_r_y, f3_r_z],
|
203 |
+
v_idx[f3_r_x + 1, f3_r_y, f3_r_z],
|
204 |
+
v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z],
|
205 |
+
v_idx[f3_r_x, f3_r_y + 1, f3_r_z],
|
206 |
+
],
|
207 |
+
axis=1
|
208 |
+
)
|
209 |
+
|
210 |
+
faces = np.concatenate(
|
211 |
+
[
|
212 |
+
faces_1_l,
|
213 |
+
faces_1_r,
|
214 |
+
faces_2_l,
|
215 |
+
faces_2_r,
|
216 |
+
faces_3_l,
|
217 |
+
faces_3_r,
|
218 |
+
], axis=0
|
219 |
+
)
|
220 |
|
221 |
vertices = self.loc + self.scale * vertices
|
222 |
mesh = trimesh.Trimesh(vertices, faces, process=False)
|
|
|
224 |
|
225 |
@property
|
226 |
def resolution(self):
|
227 |
+
assert (self.data.shape[0] == self.data.shape[1] == self.data.shape[2])
|
228 |
return self.data.shape[0]
|
229 |
|
230 |
def contains(self, points):
|
|
|
235 |
# Discretize points to [0, nx-1]^3
|
236 |
points_i = ((points + 0.5) * nx).astype(np.int32)
|
237 |
# i1, i2, i3 have sizes (batch_size, T)
|
238 |
+
i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2]
|
239 |
# Only use indices inside bounding box
|
240 |
+
mask = ((i1 >= 0) & (i2 >= 0) & (i3 >= 0) & (nx > i1) & (nx > i2) & (nx > i3))
|
|
|
|
|
|
|
241 |
# Prevent out of bounds error
|
242 |
i1 = i1[mask]
|
243 |
i2 = i2[mask]
|
|
|
275 |
vertices = (vertices + 0.5) * resolution
|
276 |
|
277 |
face_loc = vertices[faces]
|
278 |
+
occ = np.full((resolution, ) * 3, 0, dtype=np.int32)
|
279 |
face_loc = face_loc.astype(np.float32)
|
280 |
|
281 |
voxelize_mesh_(occ, face_loc)
|
|
|
285 |
|
286 |
|
287 |
def voxelize_interior(mesh, resolution):
|
288 |
+
shape = (resolution, ) * 3
|
289 |
+
bb_min = (0.5, ) * 3
|
290 |
+
bb_max = (resolution - 0.5, ) * 3
|
291 |
# Create points. Add noise to break symmetry
|
292 |
points = make_3d_grid(bb_min, bb_max, shape=shape).numpy()
|
293 |
points = points + 0.1 * (np.random.rand(*points.shape) - 0.5)
|
|
|
301 |
occ = occupancy_grid
|
302 |
|
303 |
occupied = (
|
304 |
+
occ[..., :-1, :-1, :-1] & occ[..., :-1, :-1, 1:] & occ[..., :-1, 1:, :-1] &
|
305 |
+
occ[..., :-1, 1:, 1:] & occ[..., 1:, :-1, :-1] & occ[..., 1:, :-1, 1:] &
|
306 |
+
occ[..., 1:, 1:, :-1] & occ[..., 1:, 1:, 1:]
|
|
|
|
|
|
|
|
|
|
|
307 |
)
|
308 |
return occupied
|
309 |
|
|
|
312 |
occ = occupancy_grid
|
313 |
|
314 |
unoccupied = ~(
|
315 |
+
occ[..., :-1, :-1, :-1] | occ[..., :-1, :-1, 1:] | occ[..., :-1, 1:, :-1] |
|
316 |
+
occ[..., :-1, 1:, 1:] | occ[..., 1:, :-1, :-1] | occ[..., 1:, :-1, 1:] |
|
317 |
+
occ[..., 1:, 1:, :-1] | occ[..., 1:, 1:, 1:]
|
|
|
|
|
|
|
|
|
|
|
318 |
)
|
319 |
return unoccupied
|
320 |
|
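`VoxelGrid.contains` above maps query points into voxel indices and masks everything outside the grid before indexing the occupancy array. A compact NumPy sketch of the same lookup on toy data (not the class itself):

```python
import numpy as np

nx = 8
data = np.zeros((nx, nx, nx), dtype=bool)
data[2:6, 2:6, 2:6] = True                               # a solid occupied cube

points = np.random.uniform(-0.5, 0.5, size=(100, 3))     # queries in [-0.5, 0.5]^3

# Discretize points to [0, nx-1]^3.
points_i = ((points + 0.5) * nx).astype(np.int32)
i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2]

# Only use indices inside the bounding box.
mask = (i1 >= 0) & (i2 >= 0) & (i3 >= 0) & (nx > i1) & (nx > i2) & (nx > i3)

occ = np.zeros(len(points), dtype=bool)
occ[mask] = data[i1[mask], i2[mask], i3[mask]]
print(int(occ.sum()), "of", len(points), "points fall in occupied voxels")
```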
lib/dataset/Evaluator.py
CHANGED
@@ -37,7 +37,6 @@ class _PointFaceDistance(Function):
|
|
37 |
"""
|
38 |
Torch autograd Function wrapper PointFaceDistance Cuda implementation
|
39 |
"""
|
40 |
-
|
41 |
@staticmethod
|
42 |
def forward(
|
43 |
ctx,
|
@@ -92,12 +91,15 @@ class _PointFaceDistance(Function):
|
|
92 |
grad_dists = grad_dists.contiguous()
|
93 |
points, tris, idxs = ctx.saved_tensors
|
94 |
min_triangle_area = ctx.min_triangle_area
|
95 |
-
grad_points, grad_tris = _C.point_face_dist_backward(
|
|
|
|
|
96 |
return grad_points, None, grad_tris, None, None, None
|
97 |
|
98 |
|
99 |
-
def _rand_barycentric_coords(
|
100 |
-
|
|
|
101 |
"""
|
102 |
Helper function to generate random barycentric coordinates which are uniformly
|
103 |
distributed over a triangle.
|
@@ -167,19 +169,21 @@ def sample_points_from_meshes(meshes, num_samples: int = 10000):
|
|
167 |
faces = meshes.faces_packed()
|
168 |
@@ class _PointFaceDistance @@
    """
    Torch autograd Function wrapper PointFaceDistance Cuda implementation
    """
    @staticmethod
    def forward(
        ctx,

    @staticmethod
    def backward(ctx, grad_dists):
        grad_dists = grad_dists.contiguous()
        points, tris, idxs = ctx.saved_tensors
        min_triangle_area = ctx.min_triangle_area
+       grad_points, grad_tris = _C.point_face_dist_backward(
+           points, tris, idxs, grad_dists, min_triangle_area
+       )
        return grad_points, None, grad_tris, None, None, None


+def _rand_barycentric_coords(
+    size1, size2, dtype: torch.dtype, device: torch.device
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Helper function to generate random barycentric coordinates which are uniformly
    distributed over a triangle.
    """

@@ sample points from meshes @@
    verts = meshes.verts_packed()
    faces = meshes.faces_packed()
    mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
    num_meshes = len(meshes)
+   num_valid_meshes = torch.sum(meshes.valid)    # Non empty meshes.

    # Initialize samples tensor with fill value 0 for empty meshes.
    samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)

    # Only compute samples for non empty meshes
    with torch.no_grad():
+       areas, _ = mesh_face_areas_normals(verts, faces)    # Face areas can be zero.
        max_faces = meshes.num_faces_per_mesh().max().item()
+       areas_padded = packed_to_padded(areas, mesh_to_face[meshes.valid], max_faces)    # (N, F)

        # TODO (gkioxari) Confirm multinomial bug is not present with real data.
+       samples_face_idxs = areas_padded.multinomial(
+           num_samples, replacement=True
+       )    # (N, num_samples)
        samples_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)

        # Randomly generate barycentric coords.

@@ -200,23 +204,25 @@ def point_mesh_distance(meshes, pcls, weighted=True):
        raise ValueError("meshes and pointclouds must be equal sized batches")

    # packed representation for pointclouds
+   points = pcls.points_packed()    # (P, 3)
    points_first_idx = pcls.cloud_to_packed_first_idx()
    max_points = pcls.num_points_per_cloud().max().item()

    # packed representation for faces
    verts_packed = meshes.verts_packed()
    faces_packed = meshes.faces_packed()
+   tris = verts_packed[faces_packed]    # (T, 3, 3)
    tris_first_idx = meshes.mesh_to_faces_packed_first_idx()

    # point to face distance: shape (P,)
+   point_to_face, idxs = _PointFaceDistance.apply(
+       points, points_first_idx, tris, tris_first_idx, max_points, 5e-3
+   )

    if weighted:
        # weight each example by the inverse of number of points in the example
+       point_to_cloud_idx = pcls.packed_to_cloud_idx()    # (sum(P_i),)
+       num_points_per_cloud = pcls.num_points_per_cloud()    # (N,)
        weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
        weights_p = 1.0 / weights_p.float()
        point_to_face = torch.sqrt(point_to_face) * weights_p

@@ -225,7 +231,6 @@ def point_mesh_distance(meshes, pcls, weighted=True):
class Evaluator:
    def __init__(self, device):

        self.render = Render(size=512, device=device)

@@ -253,8 +258,8 @@ class Evaluator:
        self.render.meshes = self.tgt_mesh
        tgt_normal_imgs = self.render.get_image(cam_type="four", bg="black")

+       src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0)    # [-1,1]
+       tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0)    # [-1,1]
        src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True)
        tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True)

@@ -274,8 +279,11 @@ class Evaluator:
        # error_hf = ((((src_normal_arr - tgt_normal_arr) * sim_mask)**2).sum(dim=0).mean()) * 4.0

        normal_img = Image.fromarray(
+           (
+               torch.cat([src_normal_arr, tgt_normal_arr],
+                         dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0
+           ).astype(np.uint8)
+       )
        normal_img.save(normal_path)

        return error

@@ -291,7 +299,9 @@ class Evaluator:
        p2s_dist_all, _ = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
        p2s_dist = p2s_dist_all.sum()

+       chamfer_dist = (
+           point_mesh_distance(self.tgt_mesh, src_points)[0].sum() * 100.0 + p2s_dist
+       ) * 0.5

        return chamfer_dist, p2s_dist
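The Chamfer and P2S numbers above come from point_mesh_distance: per-point point-to-surface distances are square-rooted, weighted per cloud, scaled to centimeters, and the two directions are averaged. Below is a minimal sketch of the same recipe built only from PyTorch3D's public API; `src_mesh` / `tgt_mesh` are assumed to be pytorch3d Meshes in metres, and point_mesh_face_distance averages squared distances, so its values are not numerically identical to the Evaluator's sqrt-based metric.

    # Sketch only: Chamfer / P2S in the spirit of the Evaluator above.
    from pytorch3d.ops import sample_points_from_meshes
    from pytorch3d.structures import Pointclouds
    from pytorch3d.loss import point_mesh_face_distance

    def chamfer_p2s_cm(src_mesh, tgt_mesh, num_samples=1000):
        # sample points uniformly on both surfaces
        src_points = Pointclouds(sample_points_from_meshes(src_mesh, num_samples))
        tgt_points = Pointclouds(sample_points_from_meshes(tgt_mesh, num_samples))

        # P2S: target points measured against the reconstructed (source) surface
        p2s = point_mesh_face_distance(src_mesh, tgt_points) * 100.0
        # Chamfer: symmetric average of both directions
        chamfer = (point_mesh_face_distance(tgt_mesh, src_points) * 100.0 + p2s) * 0.5
        return chamfer, p2s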
lib/dataset/NormalDataset.py
CHANGED
@@ -23,7 +23,6 @@ import torchvision.transforms as transforms
class NormalDataset:
    def __init__(self, cfg, split="train"):

        self.split = split

@@ -44,8 +43,7 @@ class NormalDataset:
        if self.split != "train":
            self.rotations = range(0, 360, 120)
        else:
+           self.rotations = np.arange(0, 360, 360 // self.opt.rotation_num).astype(np.int)

        self.datasets_dict = {}

@@ -54,26 +52,29 @@ class NormalDataset:
            dataset_dir = osp.join(self.root, dataset)

            self.datasets_dict[dataset] = {
+               "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
                "scale": self.scales[dataset_id],
            }

        self.subject_list = self.get_subject_list(split)

        # PIL to tensor
+       self.image_to_tensor = transforms.Compose(
+           [
+               transforms.Resize(self.input_size),
+               transforms.ToTensor(),
+               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+           ]
+       )

        # PIL to tensor
+       self.mask_to_tensor = transforms.Compose(
+           [
+               transforms.Resize(self.input_size),
+               transforms.ToTensor(),
+               transforms.Normalize((0.0, ), (1.0, )),
+           ]
+       )

    def get_subject_list(self, split):

@@ -88,16 +89,12 @@ class NormalDataset:
        subject_list += np.loadtxt(split_txt, dtype=str).tolist()

        if self.split != "test":
+           subject_list += subject_list[:self.bsize - len(subject_list) % self.bsize]
        print(colored(f"total: {len(subject_list)}", "yellow"))

+       bug_list = sorted(np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())

+       subject_list = [subject for subject in subject_list if (subject not in bug_list)]

        # subject_list = ["thuman2/0008"]
        return subject_list

@@ -113,48 +110,41 @@ class NormalDataset:
        rotation = self.rotations[rid]
        subject = self.subject_list[mid].split("/")[1]
        dataset = self.subject_list[mid].split("/")[0]
+       render_folder = "/".join([dataset + f"_{self.opt.rotation_num}views", subject])

        if not osp.exists(osp.join(self.root, render_folder)):
            render_folder = "/".join([dataset + f"_36views", subject])

        # setup paths
        data_dict = {
+           "dataset": dataset,
+           "subject": subject,
+           "rotation": rotation,
+           "scale": self.datasets_dict[dataset]["scale"],
+           "image_path": osp.join(self.root, render_folder, "render", f"{rotation:03d}.png"),
        }

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            if f"{name}_path" not in data_dict.keys():
+               data_dict.update(
+                   {
+                       f"{name}_path":
+                       osp.join(self.root, render_folder, name, f"{rotation:03d}.png")
+                   }
+               )
+
+           data_dict.update(
+               {
+                   name:
+                   self.imagepath2tensor(
+                       data_dict[f"{name}_path"], channel, inv=False, erasing=False
+                   )
+               }
+           )
+
+       path_keys = [key for key in data_dict.keys() if "_path" in key or "_dir" in key]

        for key in path_keys:
            del data_dict[key]

@@ -172,10 +162,9 @@ class NormalDataset:
        # simulate occlusion
        if erasing:
+           mask = kornia.augmentation.RandomErasing(
+               p=0.2, scale=(0.01, 0.2), ratio=(0.3, 3.3), keepdim=True
+           )(mask)
        image = (image * mask)[:channel]

        return (image * (0.5 - inv) * 2.0).float()
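For reference, the two Compose pipelines rebuilt above map PIL inputs to tensors in fixed ranges: RGB renders to [-1, 1] and masks to [0, 1]. A small standalone sketch, where the 512 input size and the file path are illustrative only:

    # Minimal sketch of the NormalDataset preprocessing pipelines.
    from PIL import Image
    import torchvision.transforms as transforms

    input_size = 512
    image_to_tensor = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),                                   # [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
    ])
    mask_to_tensor = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize((0.0, ), (1.0, )),                  # keeps [0, 1]
    ])

    rgba = Image.open("render/000.png").convert("RGBA")          # hypothetical path
    image = image_to_tensor(rgba.convert("RGB"))                 # (3, H, W) in [-1, 1]
    mask = mask_to_tensor(rgba.split()[-1])                      # (1, H, W) in [0, 1]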
lib/dataset/NormalModule.py
CHANGED
@@ -22,7 +22,6 @@ import pytorch_lightning as pl
class NormalModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super(NormalModule, self).__init__()
        self.cfg = cfg

@@ -40,7 +39,7 @@ class NormalModule(pl.LightningDataModule):
        self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
        self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
        self.test_dataset = NormalDataset(cfg=self.cfg, split="test")

        self.data_size = {
            "train": len(self.train_dataset),
            "val": len(self.val_dataset),

@@ -69,7 +68,7 @@ class NormalModule(pl.LightningDataModule):
        )

        return val_data_loader

    def val_dataloader(self):

        test_data_loader = DataLoader(
lib/dataset/PointFeat.py
CHANGED
@@ -6,7 +6,6 @@ from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection
class PointFeat:
    def __init__(self, verts, faces):

        # verts [B, N_vert, 3]

@@ -23,7 +22,10 @@ class PointFeat:
        if verts.shape[1] == 10475:
            faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask]
+           mouth_faces = (
+               torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1,
+                                                                            1).to(self.device)
+           )
            self.faces = torch.cat([faces, mouth_faces], dim=1).long()

        self.verts = verts.float()

@@ -35,11 +37,15 @@ class PointFeat:
        points = points.float()
        residues, pts_ind = point_mesh_distance(self.mesh, Pointclouds(points), weighted=False)

+       closest_triangles = torch.gather(
+           self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
+       ).view(-1, 3, 3)
        bary_weights = barycentric_coordinates_of_projection(points.view(-1, 3), closest_triangles)

        feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces)
+       closest_normals = torch.gather(
+           feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
+       ).view(-1, 3, 3)
        shoot_verts = ((closest_triangles * bary_weights[:, :, None]).sum(1).unsqueeze(0))

        pts2shoot_normals = points - shoot_verts
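PointFeat.query pairs point_mesh_distance's closest-face indices with a torch.gather over per-face tensors, then blends the three triangle corners with barycentric weights. A self-contained sketch of that gather-and-interpolate pattern, with made-up shapes and names:

    # Illustrative only: gather closest triangles by face index, then blend per-vertex
    # values with barycentric weights (the pattern used in PointFeat.query).
    import torch

    B, F, N = 1, 2048, 512
    triangles = torch.rand(B, F, 3, 3)           # per-face corner positions
    feat = torch.rand(B, F, 3, 16)               # per-face, per-corner features
    pts_ind = torch.randint(0, F, (N, ))         # closest face id per query point
    bary = torch.rand(N, 3)
    bary = bary / bary.sum(dim=1, keepdim=True)  # barycentric weights sum to 1

    closest_tri = torch.gather(
        triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
    ).view(-1, 3, 3)                             # (N, 3, 3)
    closest_feat = torch.gather(
        feat, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 16)
    ).view(-1, 3, 16)                            # (N, 3, 16)

    # weighted sum over the three corners
    surface_pts = (closest_tri * bary[:, :, None]).sum(dim=1)    # (N, 3)
    surface_feat = (closest_feat * bary[:, :, None]).sum(dim=1)  # (N, 16)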
lib/dataset/TestDataset.py
CHANGED
@@ -25,6 +25,7 @@ from lib.pixielib.utils.config import cfg as pixie_cfg
 from lib.pixielib.pixie import PIXIE
 from lib.pixielib.models.SMPLX import SMPLX as PIXIE_SMPLX
 from lib.common.imutils import process_image
+from lib.common.train_util import Format
 from lib.net.geometry import rotation_matrix_to_angle_axis, rot6d_to_rotmat

 from lib.pymafx.core import path_config

@@ -36,8 +37,9 @@ from lib.dataset.body_model import TetraSMPLModel
 from lib.dataset.mesh_util import get_visibility, SMPLX
 import torch.nn.functional as F
 from torchvision import transforms
+from torchvision.models import detection
+
 import os.path as osp
-import os
 import torch
 import glob
 import numpy as np

@@ -48,7 +50,6 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
class TestDataset:
    def __init__(self, cfg, device):

        self.image_dir = cfg["image_dir"]

@@ -65,7 +66,9 @@ class TestDataset:
        keep_lst = sorted(glob.glob(f"{self.image_dir}/*"))
        img_fmts = ["jpg", "png", "jpeg", "JPG", "bmp"]

+       self.subject_list = sorted(
+           [item for item in keep_lst if item.split(".")[-1] in img_fmts], reverse=False
+       )

        # smpl related
        self.smpl_data = SMPLX()

@@ -80,7 +83,16 @@ class TestDataset:
        self.smpl_model = PIXIE_SMPLX(pixie_cfg.model).to(self.device)

+       self.detector = detection.maskrcnn_resnet50_fpn(
+           weights=detection.MaskRCNN_ResNet50_FPN_V2_Weights
+       )
+       self.detector.eval()
+
+       print(
+           colored(
+               f"SMPL-X estimate with {Format.start} {self.hps_type.upper()} {Format.end}", "green"
+           )
+       )

        self.render = Render(size=512, device=self.device)

@@ -90,7 +102,9 @@ class TestDataset:
    def compute_vis_cmap(self, smpl_verts, smpl_faces):

        (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=-1)
+       smpl_vis = get_visibility(xy, z,
+                                 torch.as_tensor(smpl_faces).long()[:, :,
+                                                                    [0, 2, 1]]).unsqueeze(-1)
        smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type).unsqueeze(0)

        return {

@@ -109,7 +123,8 @@ class TestDataset:
        depth_FB[:, ~depth_mask[0]] = 0.

        # Important: index_long = depth_value - 1
+       index_z = (((depth_FB + 1.) * 0.5 * self.vol_res) - 1).clip(0, self.vol_res -
+                                                                   1).permute(1, 2, 0)
        index_z_ceil = torch.ceil(index_z).long()
        index_z_floor = torch.floor(index_z).long()
        index_z_frac = torch.frac(index_z)

@@ -121,7 +136,7 @@ class TestDataset:
        F.one_hot(index_z_floor[..., 1], self.vol_res) * (1.0 - index_z_frac[..., 1])

        voxels[index_mask] *= 0
+       voxels = torch.flip(voxels, [2]).permute(2, 0, 1).float()    #[x-2, y-0, z-1]

        return {
            "depth_voxels": voxels.flip([

@@ -139,18 +154,25 @@ class TestDataset:
        smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0])

        verts = (
+           np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * scale.item() +
+           trans.detach().cpu().numpy()
+       )
        faces = (
            np.loadtxt(
                osp.join(self.smpl_data.tedra_dir, "tetrahedrons_neutral_adult.txt"),
                dtype=np.int32,
+           ) - 1
+       )

        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

+       verts = (
+           np.pad(verts, ((0, pad_v_num),
+                          (0, 0)), mode="constant", constant_values=0.0).astype(np.float32) * 0.5
+       )
+       faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode="constant",
+                      constant_values=0.0).astype(np.int32)

        verts[:, 2] *= -1.0

@@ -168,7 +190,7 @@ class TestDataset:
        img_path = self.subject_list[index]
        img_name = img_path.split("/")[-1].rsplit(".", 1)[0]

+       arr_dict = process_image(img_path, self.hps_type, self.single, 512, self.detector)
        arr_dict.update({"name": img_name})

        with torch.no_grad():

@@ -179,7 +201,10 @@ class TestDataset:
        preds_dict, _ = self.hps.forward(batch)

        arr_dict["smpl_faces"] = (
+           torch.as_tensor(self.smpl_data.smplx_faces.astype(np.int64)).unsqueeze(0).long().to(
+               self.device
+           )
+       )
        arr_dict["type"] = self.smpl_type

        if self.hps_type == "pymafx":

@@ -198,13 +223,16 @@ class TestDataset:
        elif self.hps_type == "pixie":
            arr_dict.update(preds_dict)
            arr_dict["global_orient"] = preds_dict["global_pose"]
+           arr_dict["betas"] = preds_dict["shape"]    #200
            arr_dict["smpl_verts"] = preds_dict["vertices"]
            scale, tranX, tranY = preds_dict["cam"].split(1, dim=1)
            # 1.1435, 0.0128, 0.3520

            arr_dict["scale"] = scale.unsqueeze(1)
+           arr_dict["trans"] = (
+               torch.cat([tranX, tranY, torch.zeros_like(tranX)],
+                         dim=1).unsqueeze(1).to(self.device).float()
+           )

        # data_dict info (key-shape):
        # scale, tranX, tranY - tensor.float

@@ -230,4 +258,4 @@ class TestDataset:
        # render optimized mesh (normal, T_normal, image [-1,1])
        self.render.load_meshes(verts, faces)
+       return self.render.get_image(type="depth")
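TestDataset now builds a torchvision Mask R-CNN once in __init__ and hands it to process_image. A hedged sketch of loading such a detector and keeping person masks with the stock torchvision API; the weights enum spelling and the 0.5 threshold below are illustrative, not copied from the repo:

    # Sketch only: load a Mask R-CNN person detector and run it on one image.
    import torch
    from torchvision.io import read_image
    from torchvision.models.detection import (
        maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
    )

    weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
    detector = maskrcnn_resnet50_fpn(weights=weights).eval()

    img = read_image("examples/human.png")              # hypothetical input path
    batch = [weights.transforms()(img)]                 # preprocessing preset
    with torch.no_grad():
        out = detector(batch)[0]

    # keep confident "person" detections (COCO class id 1)
    keep = (out["labels"] == 1) & (out["scores"] > 0.5)
    person_masks = out["masks"][keep]                   # (K, 1, H, W) soft masks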
lib/dataset/body_model.py
CHANGED
@@ -21,7 +21,6 @@ import os
class SMPLModel:
    def __init__(self, model_path, age):
        """
        SMPL model.

@@ -49,20 +48,16 @@ class SMPLModel:
        if age == "kid":
            v_template_smil = np.load(
+               os.path.join(os.path.dirname(model_path), "smpl/smpl_kid_template.npy")
+           )
            v_template_smil -= np.mean(v_template_smil, axis=0)
+           v_template_diff = np.expand_dims(v_template_smil - self.v_template, axis=2)
            self.shapedirs = np.concatenate(
+               (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), axis=2
+           )
            self.beta_shape[0] += 1

+       id_to_col = {self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])}
        self.parent = {
            i: id_to_col[self.kintree_table[0, i]]
            for i in range(1, self.kintree_table.shape[1])

@@ -121,33 +116,30 @@ class SMPLModel:
        pose_cube = self.pose.reshape((-1, 1, 3))
        # rotation matrix for each joint
        self.R = self.rodrigues(pose_cube)
+       I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), (self.R.shape[0] - 1, 3, 3))
        lrotmin = (self.R[1:] - I_cube).ravel()
        # how pose affect body shape in zero pose
        v_posed = v_shaped + self.posedirs.dot(lrotmin)
        # world transformation of each joint
        G = np.empty((self.kintree_table.shape[1], 4, 4))
+       G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
        for i in range(1, self.kintree_table.shape[1]):
            G[i] = G[self.parent[i]].dot(
                self.with_zeros(
+                   np.hstack(
+                       [
+                           self.R[i],
+                           ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
+                       ]
+                   )
+               )
+           )
        # remove the transformation due to the rest pose
+       G = G - self.pack(np.matmul(G, np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
        # transformation of each vertex
        T = np.tensordot(self.weights, G, axes=[[1], [0]])
        rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
+       v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
        self.verts = v + self.trans.reshape([1, 3])
        self.G = G

@@ -171,19 +163,20 @@ class SMPLModel:
        r_hat = r / theta
        cos = np.cos(theta)
        z_stick = np.zeros(theta.shape[0])
+       m = np.dstack(
+           [
+               z_stick,
+               -r_hat[:, 0, 2],
+               r_hat[:, 0, 1],
+               r_hat[:, 0, 2],
+               z_stick,
+               -r_hat[:, 0, 0],
+               -r_hat[:, 0, 1],
+               r_hat[:, 0, 0],
+               z_stick,
+           ]
+       ).reshape([-1, 3, 3])
+       i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3])
        A = np.transpose(r_hat, axes=[0, 2, 1])
        B = r_hat
        dot = np.matmul(A, B)

@@ -238,12 +231,7 @@ class TetraSMPLModel:
class TetraSMPLModel:
+   def __init__(self, model_path, model_addition_path, age="adult", v_template=None):
        """
        SMPL model.

@@ -276,10 +264,7 @@ class TetraSMPLModel:
        self.posedirs_added = params_added["posedirs_added"]
        self.tetrahedrons = params_added["tetrahedrons"]

+       id_to_col = {self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])}
        self.parent = {
            i: id_to_col[self.kintree_table[0, i]]
            for i in range(1, self.kintree_table.shape[1])

@@ -291,14 +276,13 @@ class TetraSMPLModel:
        if age == "kid":
            v_template_smil = np.load(
+               os.path.join(os.path.dirname(model_path), "smpl_kid_template.npy")
+           )
            v_template_smil -= np.mean(v_template_smil, axis=0)
+           v_template_diff = np.expand_dims(v_template_smil - self.v_template, axis=2)
            self.shapedirs = np.concatenate(
+               (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), axis=2
+           )
            self.beta_shape[0] += 1

        self.pose = np.zeros(self.pose_shape)

@@ -356,50 +340,42 @@ class TetraSMPLModel:
        """
        # how beta affect body shape
        v_shaped = self.shapedirs.dot(self.beta) + self.v_template
+       v_shaped_added = self.shapedirs_added.dot(self.beta) + self.v_template_added
        # joints location
        self.J = self.J_regressor.dot(v_shaped)
        pose_cube = self.pose.reshape((-1, 1, 3))
        # rotation matrix for each joint
        self.R = self.rodrigues(pose_cube)
+       I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), (self.R.shape[0] - 1, 3, 3))
        lrotmin = (self.R[1:] - I_cube).ravel()
        # how pose affect body shape in zero pose
        v_posed = v_shaped + self.posedirs.dot(lrotmin)
        v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin)
        # world transformation of each joint
        G = np.empty((self.kintree_table.shape[1], 4, 4))
+       G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
        for i in range(1, self.kintree_table.shape[1]):
            G[i] = G[self.parent[i]].dot(
                self.with_zeros(
+                   np.hstack(
+                       [
+                           self.R[i],
+                           ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])),
+                       ]
+                   )
+               )
+           )
        # remove the transformation due to the rest pose
+       G = G - self.pack(np.matmul(G, np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
        self.G = G
        # transformation of each vertex
        T = np.tensordot(self.weights, G, axes=[[1], [0]])
        rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
+       v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
        self.verts = v + self.trans.reshape([1, 3])
        T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]])
+       rest_shape_added_h = np.hstack((v_posed_added, np.ones([v_posed_added.shape[0], 1])))
+       v_added = np.matmul(T_added, rest_shape_added_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
        self.verts_added = v_added + self.trans.reshape([1, 3])

@@ -422,19 +398,20 @@ class TetraSMPLModel:
    def rodrigues(self, r):
        (same np.dstack skew-matrix and i_cube reflow as in SMPLModel.rodrigues above)
        A = np.transpose(r_hat, axes=[0, 2, 1])
        B = r_hat
        dot = np.matmul(A, B)
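Both rodrigues methods implement the Rodrigues formula R = cos(t)*I + (1 - cos(t)) * r_hat r_hat^T + sin(t) * [r_hat]_x, where the np.dstack block above builds the skew-symmetric matrix [r_hat]_x. A single-vector NumPy sketch of the same formula, for reference:

    # Single axis-angle vector version of the batched rodrigues() above.
    import numpy as np

    def rodrigues_single(r, eps=1e-8):
        theta = np.linalg.norm(r) + eps
        r_hat = r / theta
        K = np.array([                      # skew-symmetric cross-product matrix
            [0.0, -r_hat[2], r_hat[1]],
            [r_hat[2], 0.0, -r_hat[0]],
            [-r_hat[1], r_hat[0], 0.0],
        ])
        return (np.cos(theta) * np.eye(3)
                + (1.0 - np.cos(theta)) * np.outer(r_hat, r_hat)
                + np.sin(theta) * K)

    R = rodrigues_single(np.array([0.0, 0.0, np.pi / 2]))   # ~90 deg around z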
lib/dataset/mesh_util.py
CHANGED
@@ -14,32 +14,33 @@
|
|
14 |
#
|
15 |
# Contact: [email protected]
|
16 |
|
|
|
17 |
import numpy as np
|
18 |
-
import cv2
|
19 |
-
import pymeshlab
|
20 |
import torch
|
21 |
import torchvision
|
22 |
import trimesh
|
23 |
-
import
|
24 |
-
|
25 |
import os.path as osp
|
26 |
import _pickle as cPickle
|
|
|
27 |
from scipy.spatial import cKDTree
|
28 |
|
29 |
from pytorch3d.structures import Meshes
|
30 |
import torch.nn.functional as F
|
31 |
import lib.smplx as smplx
|
|
|
32 |
from pytorch3d.renderer.mesh import rasterize_meshes
|
33 |
from PIL import Image, ImageFont, ImageDraw
|
34 |
from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
|
35 |
-
import tinyobjloader
|
36 |
|
37 |
-
from lib.common.imutils import uncrop
|
38 |
-
from lib.common.render_utils import Pytorch3dRasterizer
|
39 |
|
|
|
|
|
|
|
40 |
|
41 |
-
class SMPLX:
|
42 |
|
|
|
43 |
def __init__(self):
|
44 |
|
45 |
self.current_dir = osp.join(osp.dirname(__file__), "../../data/smpl_related")
|
@@ -54,10 +55,14 @@ class SMPLX:
|
|
54 |
|
55 |
self.smplx_eyeball_fid_path = osp.join(self.current_dir, "smpl_data/eyeball_fid.npy")
|
56 |
self.smplx_fill_mouth_fid_path = osp.join(self.current_dir, "smpl_data/fill_mouth_fid.npy")
|
57 |
-
self.smplx_flame_vid_path = osp.join(
|
|
|
|
|
58 |
self.smplx_mano_vid_path = osp.join(self.current_dir, "smpl_data/MANO_SMPLX_vertex_ids.pkl")
|
59 |
self.front_flame_path = osp.join(self.current_dir, "smpl_data/FLAME_face_mask_ids.npy")
|
60 |
-
self.smplx_vertex_lmkid_path = osp.join(
|
|
|
|
|
61 |
|
62 |
self.smplx_faces = np.load(self.smplx_faces_path)
|
63 |
self.smplx_verts = np.load(self.smplx_verts_path)
|
@@ -68,84 +73,51 @@ class SMPLX:
|
|
68 |
self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid_path)
|
69 |
self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid_path)
|
70 |
self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True)
|
71 |
-
self.smplx_mano_vid = np.concatenate(
|
|
|
|
|
72 |
self.smplx_flame_vid = np.load(self.smplx_flame_vid_path, allow_pickle=True)
|
73 |
self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)]
|
74 |
|
75 |
# hands
|
76 |
-
self.mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(
|
|
|
|
|
77 |
# face
|
78 |
-
self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(
|
79 |
-
0, torch.tensor(self.smplx_front_flame_vid), 1.0
|
80 |
-
|
81 |
-
|
|
|
|
|
82 |
|
83 |
self.smplx_to_smpl = cPickle.load(open(self.smplx_to_smplx_path, "rb"))
|
84 |
|
85 |
self.model_dir = osp.join(self.current_dir, "models")
|
86 |
self.tedra_dir = osp.join(self.current_dir, "../tedra_data")
|
87 |
|
88 |
-
self.ghum_smpl_pairs = torch.tensor(
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
(13, 18),
|
97 |
-
(14, 19),
|
98 |
-
(15, 20),
|
99 |
-
(16, 21),
|
100 |
-
(17, 39),
|
101 |
-
(18, 44),
|
102 |
-
(19, 36),
|
103 |
-
(20, 41),
|
104 |
-
(21, 35),
|
105 |
-
(22, 40),
|
106 |
-
(23, 1),
|
107 |
-
(24, 2),
|
108 |
-
(25, 4),
|
109 |
-
(26, 5),
|
110 |
-
(27, 7),
|
111 |
-
(28, 8),
|
112 |
-
(29, 31),
|
113 |
-
(30, 34),
|
114 |
-
(31, 29),
|
115 |
-
(32, 32),
|
116 |
-
]).long()
|
117 |
|
118 |
# smpl-smplx correspondence
|
119 |
self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
|
120 |
self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68]
|
121 |
-
self.smpl_joint_ids_45 =
|
122 |
-
|
123 |
-
self.extra_joint_ids = (
|
124 |
-
|
125 |
-
61,
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
57,
|
132 |
-
56,
|
133 |
-
64,
|
134 |
-
59,
|
135 |
-
67,
|
136 |
-
75,
|
137 |
-
70,
|
138 |
-
65,
|
139 |
-
60,
|
140 |
-
61,
|
141 |
-
63,
|
142 |
-
62,
|
143 |
-
76,
|
144 |
-
71,
|
145 |
-
72,
|
146 |
-
74,
|
147 |
-
73,
|
148 |
-
]) + 68)
|
149 |
|
150 |
self.smpl_joint_ids_45_pixie = (np.arange(22).tolist() + self.extra_joint_ids.tolist())
|
151 |
|
@@ -222,27 +194,6 @@ def load_fit_body(fitted_path, scale, smpl_type="smplx", smpl_gender="neutral",
|
|
222 |
return smpl_mesh, smpl_joints
|
223 |
|
224 |
|
225 |
-
def create_grid_points_from_xyz_bounds(bound, res):
|
226 |
-
|
227 |
-
min_x, max_x, min_y, max_y, min_z, max_z = bound
|
228 |
-
x = torch.linspace(min_x, max_x, res)
|
229 |
-
y = torch.linspace(min_y, max_y, res)
|
230 |
-
z = torch.linspace(min_z, max_z, res)
|
231 |
-
X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
|
232 |
-
|
233 |
-
return torch.stack([X, Y, Z], dim=-1)
|
234 |
-
|
235 |
-
|
236 |
-
def create_grid_points_from_xy_bounds(bound, res):
|
237 |
-
|
238 |
-
min_x, max_x, min_y, max_y = bound
|
239 |
-
x = torch.linspace(min_x, max_x, res)
|
240 |
-
y = torch.linspace(min_y, max_y, res)
|
241 |
-
X, Y = torch.meshgrid(x, y, indexing='ij')
|
242 |
-
|
243 |
-
return torch.stack([X, Y], dim=-1)
|
244 |
-
|
245 |
-
|
246 |
def apply_face_mask(mesh, face_mask):
|
247 |
|
248 |
mesh.update_faces(face_mask)
|
@@ -277,7 +228,8 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
|
|
277 |
|
278 |
part_extractor = PointFeat(
|
279 |
torch.tensor(part_mesh.vertices).unsqueeze(0).to(device),
|
280 |
-
torch.tensor(part_mesh.faces).unsqueeze(0).to(device)
|
|
|
281 |
|
282 |
(part_dist, _) = part_extractor.query(torch.tensor(full_mesh.vertices).unsqueeze(0).to(device))
|
283 |
|
@@ -286,12 +238,20 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
|
|
286 |
if region == "hand":
|
287 |
_, idx = smpl_tree.query(full_mesh.vertices, k=1)
|
288 |
full_lmkid = SMPL_container.smplx_vertex_lmkid[idx]
|
289 |
-
remove_mask = torch.logical_and(
|
|
|
|
|
|
|
290 |
|
291 |
elif region == "face":
|
292 |
_, idx = smpl_tree.query(full_mesh.vertices, k=5)
|
293 |
-
face_space_mask = torch.isin(
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
295 |
|
296 |
BNI_part_mask = ~(remove_mask).flatten()[full_mesh.faces].any(dim=1)
|
297 |
full_mesh.update_faces(BNI_part_mask.detach().cpu())
|
@@ -303,109 +263,6 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr
|
|
303 |
return full_mesh
|
304 |
|
305 |
|
306 |
-
def cross(triangles):
|
307 |
-
"""
|
308 |
-
Returns the cross product of two edges from input triangles
|
309 |
-
Parameters
|
310 |
-
--------------
|
311 |
-
triangles: (n, 3, 3) float
|
312 |
-
Vertices of triangles
|
313 |
-
Returns
|
314 |
-
--------------
|
315 |
-
crosses : (n, 3) float
|
316 |
-
Cross product of two edge vectors
|
317 |
-
"""
|
318 |
-
vectors = np.diff(triangles, axis=1)
|
319 |
-
crosses = np.cross(vectors[:, 0], vectors[:, 1])
|
320 |
-
return crosses
|
321 |
-
|
322 |
-
|
323 |
-
def tri_area(triangles=None, crosses=None, sum=False):
|
324 |
-
"""
|
325 |
-
Calculates the sum area of input triangles
|
326 |
-
Parameters
|
327 |
-
----------
|
328 |
-
triangles : (n, 3, 3) float
|
329 |
-
Vertices of triangles
|
330 |
-
crosses : (n, 3) float or None
|
331 |
-
As a speedup don't re- compute cross products
|
332 |
-
sum : bool
|
333 |
-
Return summed area or individual triangle area
|
334 |
-
Returns
|
335 |
-
----------
|
336 |
-
area : (n,) float or float
|
337 |
-
Individual or summed area depending on `sum` argument
|
338 |
-
"""
|
339 |
-
if crosses is None:
|
340 |
-
crosses = cross(triangles)
|
341 |
-
area = (np.sum(crosses**2, axis=1)**.5) * .5
|
342 |
-
if sum:
|
343 |
-
return np.sum(area)
|
344 |
-
return area
|
345 |
-
|
346 |
-
|
347 |
-
def sample_surface(triangles, count, area=None):
|
348 |
-
"""
|
349 |
-
Sample the surface of a mesh, returning the specified
|
350 |
-
number of points
|
351 |
-
For individual triangle sampling uses this method:
|
352 |
-
http://mathworld.wolfram.com/TrianglePointPicking.html
|
353 |
-
Parameters
|
354 |
-
---------
|
355 |
-
triangles : (n, 3, 3) float
|
356 |
-
Vertices of triangles
|
357 |
-
count : int
|
358 |
-
Number of points to return
|
359 |
-
Returns
|
360 |
-
---------
|
361 |
-
samples : (count, 3) float
|
362 |
-
Points in space on the surface of mesh
|
363 |
-
face_index : (count,) int
|
364 |
-
Indices of faces for each sampled point
|
365 |
-
"""
|
366 |
-
|
367 |
-
# len(mesh.faces) float, array of the areas
|
368 |
-
# of each face of the mesh
|
369 |
-
if area is None:
|
370 |
-
area = tri_area(triangles)
|
371 |
-
|
372 |
-
# total area (float)
|
373 |
-
area_sum = np.sum(area)
|
374 |
-
# cumulative area (len(mesh.faces))
|
375 |
-
area_cum = np.cumsum(area)
|
376 |
-
face_pick = np.random.random(count) * area_sum
|
377 |
-
face_index = np.searchsorted(area_cum, face_pick)
|
378 |
-
|
379 |
-
# pull triangles into the form of an origin + 2 vectors
|
380 |
-
tri_origins = triangles[:, 0]
|
381 |
-
tri_vectors = triangles[:, 1:].copy()
|
382 |
-
tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))
|
383 |
-
|
384 |
-
# pull the vectors for the faces we are going to sample from
|
385 |
-
tri_origins = tri_origins[face_index]
|
386 |
-
tri_vectors = tri_vectors[face_index]
|
387 |
-
|
388 |
-
# randomly generate two 0-1 scalar components to multiply edge vectors by
|
389 |
-
random_lengths = np.random.random((len(tri_vectors), 2, 1))
|
390 |
-
|
391 |
-
# points will be distributed on a quadrilateral if we use 2 0-1 samples
|
392 |
-
# if the two scalar components sum less than 1.0 the point will be
|
393 |
-
# inside the triangle, so we find vectors longer than 1.0 and
|
394 |
-
# transform them to be inside the triangle
|
395 |
-
random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0
|
396 |
-
random_lengths[random_test] -= 1.0
|
397 |
-
random_lengths = np.abs(random_lengths)
|
398 |
-
|
399 |
-
# multiply triangle edge vectors by the random lengths and sum
|
400 |
-
sample_vector = (tri_vectors * random_lengths).sum(axis=1)
|
401 |
-
|
402 |
-
# finally, offset by the origin to generate
|
403 |
-
# (n,3) points in space on the triangle
|
404 |
-
samples = torch.tensor(sample_vector + tri_origins).float()
|
405 |
-
|
406 |
-
return samples, face_index
|
407 |
-
|
408 |
-
|
409 |
def obj_loader(path, with_uv=True):
|
410 |
# Create reader.
|
411 |
reader = tinyobjloader.ObjReader()
|
@@ -424,8 +281,8 @@ def obj_loader(path, with_uv=True):
|
|
424 |
f_vt = tri[:, [2, 5, 8]]
|
425 |
|
426 |
if with_uv:
|
427 |
-
face_uvs = vt[f_vt].mean(axis=1)
|
428 |
-
vert_uvs = np.zeros((v.shape[0], 2), dtype=np.float32)
|
429 |
vert_uvs[f_v.reshape(-1)] = vt[f_vt.reshape(-1)]
|
430 |
|
431 |
return v, f_v, vert_uvs, face_uvs
|
@@ -434,7 +291,6 @@ def obj_loader(path, with_uv=True):
|
|
434 |
|
435 |
|
436 |
class HoppeMesh:
|
437 |
-
|
438 |
def __init__(self, verts, faces, uvs=None, texture=None):
|
439 |
"""
|
440 |
The HoppeSDF calculates signed distance towards a predefined oriented point cloud
|
@@ -459,34 +315,20 @@ class HoppeMesh:
|
|
459 |
- points: [n, 3]
|
460 |
- return: [n, 4] rgba
|
461 |
"""
|
462 |
-
triangles = self.verts[faces]
|
463 |
-
barycentric = trimesh.triangles.points_to_barycentric(triangles, points)
|
464 |
-
vert_colors = self.vertex_colors[faces]
|
465 |
point_colors = torch.tensor((barycentric[:, :, None] * vert_colors).sum(axis=1)).float()
|
466 |
return point_colors
|
467 |
|
468 |
def triangles(self):
|
469 |
-
return self.verts[self.faces].numpy()
|
470 |
|
471 |
|
472 |
def tensor2variable(tensor, device):
|
473 |
return tensor.requires_grad_(True).to(device)
|
474 |
|
475 |
|
476 |
-
class GMoF(torch.nn.Module):
|
477 |
-
|
478 |
-
def __init__(self, rho=1):
|
479 |
-
super(GMoF, self).__init__()
|
480 |
-
self.rho = rho
|
481 |
-
|
482 |
-
def extra_repr(self):
|
483 |
-
return "rho = {}".format(self.rho)
|
484 |
-
|
485 |
-
def forward(self, residual):
|
486 |
-
dist = torch.div(residual, residual + self.rho**2)
|
487 |
-
return self.rho**2 * dist
|
488 |
-
|
489 |
-
|
490 |
def mesh_edge_loss(meshes, target_length: float = 0.0):
|
491 |
"""
|
492 |
Computes mesh edge length regularization loss averaged across all meshes
|
@@ -508,10 +350,10 @@ def mesh_edge_loss(meshes, target_length: float = 0.0):
|
|
508 |
return torch.tensor([0.0], dtype=torch.float32, device=meshes.device, requires_grad=True)
|
509 |
|
510 |
N = len(meshes)
|
511 |
-
edges_packed = meshes.edges_packed()
|
512 |
-
verts_packed = meshes.verts_packed()
|
513 |
-
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()
|
514 |
-
num_edges_per_mesh = meshes.num_edges_per_mesh()
|
515 |
|
516 |
# Determine the weight for each edge based on the number of edges in the
|
517 |
# mesh it corresponds to.
|
@@ -531,99 +373,37 @@ def mesh_edge_loss(meshes, target_length: float = 0.0):
|
|
531 |
return loss_all
|
532 |
|
533 |
|
534 |
-
def
|
535 |
-
|
536 |
-
obj.export(obj_path)
|
537 |
-
ms = pymeshlab.MeshSet()
|
538 |
-
ms.load_new_mesh(obj_path)
|
539 |
-
# ms.meshing_decimation_quadric_edge_collapse(targetfacenum=100000)
|
540 |
-
ms.meshing_isotropic_explicit_remeshing(targetlen=pymeshlab.Percentage(0.5), adaptive=True)
|
541 |
-
ms.apply_coord_laplacian_smoothing()
|
542 |
-
ms.save_current_mesh(obj_path[:-4] + "_remesh.obj")
|
543 |
-
polished_mesh = trimesh.load_mesh(obj_path[:-4] + "_remesh.obj")
|
544 |
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
ms = pymeshlab.MeshSet()
|
551 |
-
ms.load_new_mesh(obj_path)
|
552 |
-
ms.meshing_decimation_quadric_edge_collapse(targetfacenum=50000)
|
553 |
-
# ms.apply_coord_laplacian_smoothing()
|
554 |
-
ms.save_current_mesh(obj_path)
|
555 |
-
# ms.save_current_mesh(obj_path.replace(".obj", ".ply"))
|
556 |
-
polished_mesh = trimesh.load_mesh(obj_path)
|
557 |
|
558 |
-
return
|
559 |
|
560 |
|
561 |
def poisson(mesh, obj_path, depth=10):
|
562 |
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
|
|
571 |
|
572 |
-
|
|
|
|
|
573 |
|
574 |
-
|
|
|
575 |
|
576 |
-
|
577 |
-
def get_mask(tensor, dim):
|
578 |
-
|
579 |
-
mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0
|
580 |
-
mask = mask.type_as(tensor)
|
581 |
-
|
582 |
-
return mask
|
583 |
-
|
584 |
-
|
585 |
-
def blend_rgb_norm(norms, data):
|
586 |
-
|
587 |
-
# norms [N, 3, res, res]
|
588 |
-
|
589 |
-
masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1)
|
590 |
-
norm_mask = F.interpolate(
|
591 |
-
torch.cat([norms, masks], dim=1).detach().cpu(),
|
592 |
-
size=data["uncrop_param"]["box_shape"],
|
593 |
-
mode="bilinear",
|
594 |
-
align_corners=False).permute(0, 2, 3, 1).numpy()
|
595 |
-
final = data["img_raw"]
|
596 |
-
|
597 |
-
for idx in range(len(norms)):
|
598 |
-
|
599 |
-
norm_pred = (norm_mask[idx, :, :, :3] + 1.0) * 255.0 / 2.0
|
600 |
-
mask_pred = np.repeat(norm_mask[idx, :, :, 3:4], 3, axis=-1)
|
601 |
-
|
602 |
-
norm_ori = unwrap(norm_pred, data["uncrop_param"], idx)
|
603 |
-
mask_ori = unwrap(mask_pred, data["uncrop_param"], idx)
|
604 |
-
|
605 |
-
final = final * (1.0 - mask_ori) + norm_ori * mask_ori
|
606 |
-
|
607 |
-
return final.astype(np.uint8)
|
608 |
-
|
609 |
-
|
610 |
-
def unwrap(image, uncrop_param, idx):
|
611 |
-
|
612 |
-
img_uncrop = uncrop(
|
613 |
-
image,
|
614 |
-
uncrop_param["center"][idx],
|
615 |
-
uncrop_param["scale"][idx],
|
616 |
-
uncrop_param["crop_shape"],
|
617 |
-
)
|
618 |
-
|
619 |
-
img_orig = cv2.warpAffine(
|
620 |
-
img_uncrop,
|
621 |
-
np.linalg.inv(uncrop_param["M"])[:2, :],
|
622 |
-
uncrop_param["ori_shape"][::-1],
|
623 |
-
flags=cv2.INTER_CUBIC,
|
624 |
-
)
|
625 |
-
|
626 |
-
return img_orig
|
627 |
|
628 |
|
629 |
# Losses to smooth / regularize the mesh shape
|
@@ -634,60 +414,7 @@ def update_mesh_shape_prior_losses(mesh, losses):
|
|
634 |
# mesh normal consistency
|
635 |
losses["nc"]["value"] = mesh_normal_consistency(mesh)
|
636 |
# mesh laplacian smoothing
|
637 |
-
losses["
|
638 |
-
|
639 |
-
|
640 |
-
def rename(old_dict, old_name, new_name):
|
641 |
-
new_dict = {}
|
642 |
-
for key, value in zip(old_dict.keys(), old_dict.values()):
|
643 |
-
new_key = key if key != old_name else new_name
|
644 |
-
new_dict[new_key] = old_dict[key]
|
645 |
-
return new_dict
|
646 |
-
|
647 |
-
|
648 |
-
def load_checkpoint(model, cfg):
|
649 |
-
|
650 |
-
model_dict = model.state_dict()
|
651 |
-
main_dict = {}
|
652 |
-
normal_dict = {}
|
653 |
-
|
654 |
-
device = torch.device(f"cuda:{cfg['test_gpus'][0]}")
|
655 |
-
|
656 |
-
if os.path.exists(cfg.resume_path) and cfg.resume_path.endswith("ckpt"):
|
657 |
-
main_dict = torch.load(cfg.resume_path, map_location=device)["state_dict"]
|
658 |
-
|
659 |
-
main_dict = {
|
660 |
-
k: v for k, v in main_dict.items() if k in model_dict and v.shape == model_dict[k].shape and
|
661 |
-
("reconEngine" not in k) and ("normal_filter" not in k) and ("voxelization" not in k)
|
662 |
-
}
|
663 |
-
print(colored(f"Resume MLP weights from {cfg.resume_path}", "green"))
|
664 |
-
|
665 |
-
if os.path.exists(cfg.normal_path) and cfg.normal_path.endswith("ckpt"):
|
666 |
-
normal_dict = torch.load(cfg.normal_path, map_location=device)["state_dict"]
|
667 |
-
|
668 |
-
for key in normal_dict.keys():
|
669 |
-
normal_dict = rename(normal_dict, key, key.replace("netG", "netG.normal_filter"))
|
670 |
-
|
671 |
-
normal_dict = {k: v for k, v in normal_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
|
672 |
-
print(colored(f"Resume normal model from {cfg.normal_path}", "green"))
|
673 |
-
|
674 |
-
model_dict.update(main_dict)
|
675 |
-
model_dict.update(normal_dict)
|
676 |
-
model.load_state_dict(model_dict)
|
677 |
-
|
678 |
-
model.netG = model.netG.to(device)
|
679 |
-
model.reconEngine = model.reconEngine.to(device)
|
680 |
-
|
681 |
-
model.netG.training = False
|
682 |
-
model.netG.eval()
|
683 |
-
|
684 |
-
del main_dict
|
685 |
-
del normal_dict
|
686 |
-
del model_dict
|
687 |
-
|
688 |
-
torch.cuda.empty_cache()
|
689 |
-
|
690 |
-
return model
|
691 |
|
692 |
|
693 |
def read_smpl_constants(folder):
|
@@ -706,8 +433,10 @@ def read_smpl_constants(folder):
|
|
706 |
smpl_vertex_code = np.float32(np.copy(smpl_vtx_std))
|
707 |
"""Load smpl faces & tetrahedrons"""
|
708 |
smpl_faces = np.loadtxt(os.path.join(folder, "faces.txt"), dtype=np.int32) - 1
|
709 |
-
smpl_face_code = (
|
710 |
-
|
|
|
|
|
711 |
smpl_tetras = (np.loadtxt(os.path.join(folder, "tetrahedrons.txt"), dtype=np.int32) - 1)
|
712 |
|
713 |
return_dict = {
|
@@ -720,19 +449,6 @@ def read_smpl_constants(folder):
|
|
720 |
return return_dict
|
721 |
|
722 |
|
723 |
-
def feat_select(feat, select):
|
724 |
-
|
725 |
-
# feat [B, featx2, N]
|
726 |
-
# select [B, 1, N]
|
727 |
-
# return [B, feat, N]
|
728 |
-
|
729 |
-
dim = feat.shape[1] // 2
|
730 |
-
idx = torch.tile((1 - select), (1, dim, 1)) * dim + torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select)
|
731 |
-
feat_select = torch.gather(feat, 1, idx.long())
|
732 |
-
|
733 |
-
return feat_select
|
734 |
-
|
735 |
-
|
736 |
def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel=1):
|
737 |
"""get the visibility of vertices
|
738 |
|
@@ -771,7 +487,9 @@ def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel
|
|
771 |
|
772 |
for idx in range(N_body):
|
773 |
Num_faces = len(faces[idx])
|
774 |
-
vis_vertices_id = torch.unique(
|
|
|
|
|
775 |
vis_mask[idx, vis_vertices_id] = 1.0
|
776 |
|
777 |
# print("------------------------\n")
|
@@ -825,7 +543,7 @@ def orthogonal(points, calibrations, transforms=None):
|
|
825 |
"""
|
826 |
rot = calibrations[:, :3, :3]
|
827 |
trans = calibrations[:, :3, 3:4]
|
828 |
-
pts = torch.baddbmm(trans, rot, points)
|
829 |
if transforms is not None:
|
830 |
scale = transforms[:2, :2]
|
831 |
shift = transforms[:2, 2:3]
|
@@ -925,37 +643,14 @@ def compute_normal_batch(vertices, faces):
|
|
925 |
return vert_norm
|
926 |
|
927 |
|
928 |
-
def calculate_mIoU(outputs, labels):
|
929 |
-
|
930 |
-
SMOOTH = 1e-6
|
931 |
-
|
932 |
-
outputs = outputs.int()
|
933 |
-
labels = labels.int()
|
934 |
-
|
935 |
-
intersection = ((outputs & labels).float().sum()) # Will be zero if Truth=0 or Prediction=0
|
936 |
-
union = (outputs | labels).float().sum() # Will be zero if both are 0
|
937 |
-
|
938 |
-
iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our division to avoid 0/0
|
939 |
-
|
940 |
-
thresholded = (torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10) # This is equal to comparing with thresholds
|
941 |
-
|
942 |
-
return (thresholded.mean().detach().cpu().numpy()
|
943 |
-
) # Or thresholded.mean() if you are interested in average across the batch
|
944 |
-
|
945 |
-
|
946 |
-
def add_alpha(colors, alpha=0.7):
|
947 |
-
|
948 |
-
colors_pad = np.pad(colors, ((0, 0), (0, 1)), mode="constant", constant_values=alpha)
|
949 |
-
|
950 |
-
return colors_pad
|
951 |
-
|
952 |
-
|
953 |
def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
|
954 |
|
955 |
font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf")
|
956 |
font = ImageFont.truetype(font_path, 30)
|
957 |
grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0), nrow=nrow, padding=0)
|
958 |
-
grid_img = Image.fromarray(
|
|
|
|
|
959 |
|
960 |
if False:
|
961 |
# add text
|
@@ -965,16 +660,20 @@ def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
|
|
965 |
draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
|
966 |
|
967 |
if type == "smpl":
|
968 |
-
for col_id, col_txt in enumerate(
|
|
|
969 |
"image",
|
970 |
"smpl-norm(render)",
|
971 |
"cloth-norm(pred)",
|
972 |
"diff-norm",
|
973 |
"diff-mask",
|
974 |
-
|
|
|
975 |
draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
|
976 |
elif type == "cloth":
|
977 |
-
for col_id, col_txt in enumerate(
|
|
|
|
|
978 |
draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
|
979 |
for col_id, col_txt in enumerate(["0", "90", "180", "270"]):
|
980 |
draw.text(
|
@@ -996,12 +695,9 @@ def clean_mesh(verts, faces):
|
|
996 |
device = verts.device
|
997 |
|
998 |
mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(), faces.detach().cpu().numpy())
|
999 |
-
|
1000 |
-
|
1001 |
-
|
1002 |
-
mesh_clean = mesh_lst[comp_num.index(max(comp_num))]
|
1003 |
-
final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device)
|
1004 |
-
final_faces = torch.as_tensor(mesh_clean.faces).long().to(device)
|
1005 |
|
1006 |
return final_verts, final_faces
|
1007 |
|
|
|
14 |
#
|
15 |
# Contact: [email protected]
|
16 |
|
17 |
+
import os
|
18 |
import numpy as np
|
|
|
|
|
19 |
import torch
|
20 |
import torchvision
|
21 |
import trimesh
|
22 |
+
import open3d as o3d
|
23 |
+
import tinyobjloader
|
24 |
import os.path as osp
|
25 |
import _pickle as cPickle
|
26 |
+
from termcolor import colored
|
27 |
from scipy.spatial import cKDTree
|
28 |
|
29 |
from pytorch3d.structures import Meshes
|
30 |
import torch.nn.functional as F
|
31 |
import lib.smplx as smplx
|
32 |
+
from lib.common.render_utils import Pytorch3dRasterizer
|
33 |
from pytorch3d.renderer.mesh import rasterize_meshes
|
34 |
from PIL import Image, ImageFont, ImageDraw
|
35 |
from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency
|
|
|
36 |
|
|
|
|
|
37 |
|
38 |
+
class Format:
|
39 |
+
end = '\033[0m'
|
40 |
+
start = '\033[4m'
|
41 |
|
|
|
42 |
|
43 |
+
class SMPLX:
|
44 |
def __init__(self):
|
45 |
|
46 |
self.current_dir = osp.join(osp.dirname(__file__), "../../data/smpl_related")
|
|
|
55 |
|
56 |
self.smplx_eyeball_fid_path = osp.join(self.current_dir, "smpl_data/eyeball_fid.npy")
|
57 |
self.smplx_fill_mouth_fid_path = osp.join(self.current_dir, "smpl_data/fill_mouth_fid.npy")
|
58 |
+
self.smplx_flame_vid_path = osp.join(
|
59 |
+
self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy"
|
60 |
+
)
|
61 |
self.smplx_mano_vid_path = osp.join(self.current_dir, "smpl_data/MANO_SMPLX_vertex_ids.pkl")
|
62 |
self.front_flame_path = osp.join(self.current_dir, "smpl_data/FLAME_face_mask_ids.npy")
|
63 |
+
self.smplx_vertex_lmkid_path = osp.join(
|
64 |
+
self.current_dir, "smpl_data/smplx_vertex_lmkid.npy"
|
65 |
+
)
|
66 |
|
67 |
self.smplx_faces = np.load(self.smplx_faces_path)
|
68 |
self.smplx_verts = np.load(self.smplx_verts_path)
|
|
|
73 |
self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid_path)
|
74 |
self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid_path)
|
75 |
self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True)
|
76 |
+
self.smplx_mano_vid = np.concatenate(
|
77 |
+
[self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]]
|
78 |
+
)
|
79 |
self.smplx_flame_vid = np.load(self.smplx_flame_vid_path, allow_pickle=True)
|
80 |
self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)]
|
81 |
|
82 |
# hands
|
83 |
+
self.mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
|
84 |
+
0, torch.tensor(self.smplx_mano_vid), 1.0
|
85 |
+
)
|
86 |
# face
|
87 |
+
self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
|
88 |
+
0, torch.tensor(self.smplx_front_flame_vid), 1.0
|
89 |
+
)
|
90 |
+
self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_(
|
91 |
+
0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0
|
92 |
+
)
|
93 |
|
94 |
self.smplx_to_smpl = cPickle.load(open(self.smplx_to_smplx_path, "rb"))
|
95 |
|
96 |
self.model_dir = osp.join(self.current_dir, "models")
|
97 |
self.tedra_dir = osp.join(self.current_dir, "../tedra_data")
|
98 |
|
99 |
+
self.ghum_smpl_pairs = torch.tensor(
|
100 |
+
[
|
101 |
+
(0, 24), (2, 26), (5, 25), (7, 28), (8, 27), (11, 16), (12, 17), (13, 18), (14, 19),
|
102 |
+
(15, 20), (16, 21), (17, 39), (18, 44), (19, 36), (20, 41), (21, 35), (22, 40),
|
103 |
+
(23, 1), (24, 2), (25, 4), (26, 5), (27, 7), (28, 8), (29, 31), (30, 34), (31, 29),
|
104 |
+
(32, 32)
|
105 |
+
]
|
106 |
+
).long()
|
|
|
|
|
|
|
107 |
|
108 |
# smpl-smplx correspondence
|
109 |
self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
|
110 |
self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68]
|
111 |
+
self.smpl_joint_ids_45 = np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist()
|
112 |
+
|
113 |
+
self.extra_joint_ids = np.array(
|
114 |
+
[
|
115 |
+
61, 72, 66, 69, 58, 68, 57, 56, 64, 59, 67, 75, 70, 65, 60, 61, 63, 62, 76, 71, 72,
|
116 |
+
74, 73
|
117 |
+
]
|
118 |
+
)
|
119 |
+
|
120 |
+
self.extra_joint_ids += 68
|
|
|
|
|
|
|
121 |
|
122 |
self.smpl_joint_ids_45_pixie = (np.arange(22).tolist() + self.extra_joint_ids.tolist())
|
123 |
|
|
|
194 |
return smpl_mesh, smpl_joints
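
The `index_fill_` masks built in `__init__` above are plain per-vertex indicator vectors over the SMPL-X template; the pattern in isolation looks like the sketch below (the vertex ids are made up for illustration, only the vertex count is the real SMPL-X one).

```python
import torch

num_verts = 10475                       # SMPL-X template vertex count
hand_vid = torch.tensor([3, 7, 42])     # illustrative ids, not the real MANO vertex ids
# 1.0 on the listed vertices, 0.0 elsewhere -> usable as a selection / blending mask
hand_mask = torch.zeros(num_verts).index_fill_(0, hand_vid, 1.0)
verts = torch.randn(num_verts, 3)
hand_verts = verts[hand_mask.bool()]    # pick out the masked vertices
```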
|
195 |
|
196 |
|
|
|
|
|
|
|
197 |
def apply_face_mask(mesh, face_mask):
|
198 |
|
199 |
mesh.update_faces(face_mask)
|
|
|
228 |
|
229 |
part_extractor = PointFeat(
|
230 |
torch.tensor(part_mesh.vertices).unsqueeze(0).to(device),
|
231 |
+
torch.tensor(part_mesh.faces).unsqueeze(0).to(device)
|
232 |
+
)
|
233 |
|
234 |
(part_dist, _) = part_extractor.query(torch.tensor(full_mesh.vertices).unsqueeze(0).to(device))
|
235 |
|
|
|
238 |
if region == "hand":
|
239 |
_, idx = smpl_tree.query(full_mesh.vertices, k=1)
|
240 |
full_lmkid = SMPL_container.smplx_vertex_lmkid[idx]
|
241 |
+
remove_mask = torch.logical_and(
|
242 |
+
remove_mask,
|
243 |
+
torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0)
|
244 |
+
)
|
245 |
|
246 |
elif region == "face":
|
247 |
_, idx = smpl_tree.query(full_mesh.vertices, k=5)
|
248 |
+
face_space_mask = torch.isin(
|
249 |
+
torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid)
|
250 |
+
)
|
251 |
+
remove_mask = torch.logical_and(
|
252 |
+
remove_mask,
|
253 |
+
face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0)
|
254 |
+
)
|
255 |
|
256 |
BNI_part_mask = ~(remove_mask).flatten()[full_mesh.faces].any(dim=1)
|
257 |
full_mesh.update_faces(BNI_part_mask.detach().cpu())
|
|
|
263 |
return full_mesh
|
264 |
|
265 |
|
|
|
|
|
|
|
|
|
|
266 |
def obj_loader(path, with_uv=True):
|
267 |
# Create reader.
|
268 |
reader = tinyobjloader.ObjReader()
|
|
|
281 |
f_vt = tri[:, [2, 5, 8]]
|
282 |
|
283 |
if with_uv:
|
284 |
+
face_uvs = vt[f_vt].mean(axis=1) #[m, 2]
|
285 |
+
vert_uvs = np.zeros((v.shape[0], 2), dtype=np.float32) #[n, 2]
|
286 |
vert_uvs[f_v.reshape(-1)] = vt[f_vt.reshape(-1)]
|
287 |
|
288 |
return v, f_v, vert_uvs, face_uvs
|
|
|
291 |
|
292 |
|
293 |
class HoppeMesh:
|
|
|
294 |
def __init__(self, verts, faces, uvs=None, texture=None):
|
295 |
"""
|
296 |
The HoppeSDF calculates signed distance towards a predefined oriented point cloud
|
|
|
315 |
- points: [n, 3]
|
316 |
- return: [n, 4] rgba
|
317 |
"""
|
318 |
+
triangles = self.verts[faces] #[n, 3, 3]
|
319 |
+
barycentric = trimesh.triangles.points_to_barycentric(triangles, points) #[n, 3]
|
320 |
+
vert_colors = self.vertex_colors[faces] #[n, 3, 4]
|
321 |
point_colors = torch.tensor((barycentric[:, :, None] * vert_colors).sum(axis=1)).float()
|
322 |
return point_colors
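
The color query above is plain barycentric interpolation: `points_to_barycentric` returns per-point weights over the three triangle corners, and the weighted sum blends the corner attributes. A tiny self-contained sketch with made-up data:

```python
import numpy as np
import trimesh

triangles = np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=np.float64)   # one triangle
points = np.array([[0.25, 0.25, 0.0]])                                        # a point on it
corner_colors = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.float64)

bary = trimesh.triangles.points_to_barycentric(triangles, points)             # [n, 3] weights
blended = (bary[:, :, None] * corner_colors).sum(axis=1)                      # [n, 3] blended colors
print(bary, blended)
```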
|
323 |
|
324 |
def triangles(self):
|
325 |
+
return self.verts[self.faces].numpy() #[n, 3, 3]
|
326 |
|
327 |
|
328 |
def tensor2variable(tensor, device):
|
329 |
return tensor.requires_grad_(True).to(device)
|
330 |
|
331 |
|
|
|
|
|
|
|
|
332 |
def mesh_edge_loss(meshes, target_length: float = 0.0):
|
333 |
"""
|
334 |
Computes mesh edge length regularization loss averaged across all meshes
|
|
|
350 |
return torch.tensor([0.0], dtype=torch.float32, device=meshes.device, requires_grad=True)
|
351 |
|
352 |
N = len(meshes)
|
353 |
+
edges_packed = meshes.edges_packed() # (sum(E_n), 3)
|
354 |
+
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
|
355 |
+
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
|
356 |
+
num_edges_per_mesh = meshes.num_edges_per_mesh() # N
|
357 |
|
358 |
# Determine the weight for each edge based on the number of edges in the
|
359 |
# mesh it corresponds to.
|
|
|
373 |
return loss_all
|
374 |
|
375 |
|
376 |
+
def remesh_laplacian(mesh, obj_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
|
378 |
+
mesh = mesh.simplify_quadratic_decimation(50000)
|
379 |
+
mesh = trimesh.smoothing.filter_humphrey(
|
380 |
+
mesh, alpha=0.1, beta=0.5, iterations=10, laplacian_operator=None
|
381 |
+
)
|
382 |
+
mesh.export(obj_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
|
384 |
+
return mesh
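
A short usage sketch for the helper above (the path is illustrative). Note that `simplify_quadratic_decimation` is the spelling used by older trimesh releases; newer ones renamed it to `simplify_quadric_decimation`, so the installed trimesh version matters here.

```python
import trimesh

# Illustrative only: decimate + Humphrey-smooth a reconstructed mesh before export.
mesh = trimesh.load("recon_full.obj", process=False)
smoothed = remesh_laplacian(mesh, "recon_full_smooth.obj")
print(smoothed.vertices.shape, smoothed.faces.shape)
```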
|
385 |
|
386 |
|
387 |
def poisson(mesh, obj_path, depth=10):
|
388 |
|
389 |
+
pcd_path = obj_path[:-4] + ".ply"
|
390 |
+
assert (mesh.vertex_normals.shape[1] == 3)
|
391 |
+
mesh.export(pcd_path)
|
392 |
+
pcl = o3d.io.read_point_cloud(pcd_path)
|
393 |
+
with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm:
|
394 |
+
mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
|
395 |
+
pcl, depth=depth, n_threads=-1
|
396 |
+
)
|
397 |
+
print(colored(f"\n Poisson completion to {Format.start} {obj_path} {Format.end}", "yellow"))
|
398 |
|
399 |
+
# only keep the largest component
|
400 |
+
largest_mesh = keep_largest(trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles)))
|
401 |
+
largest_mesh.export(obj_path)
|
402 |
|
403 |
+
# mesh decimation for faster rendering
|
404 |
+
low_res_mesh = largest_mesh.simplify_quadratic_decimation(50000)
|
405 |
|
406 |
+
return low_res_mesh
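
For reference, the same Open3D Poisson step can also be driven without the intermediate `.ply` round-trip by building the point cloud in memory. This is only a minimal sketch of the API used above; the function and variable names are illustrative, not part of the repo.

```python
import numpy as np
import open3d as o3d
import trimesh

def poisson_from_trimesh(src: trimesh.Trimesh, depth: int = 10) -> trimesh.Trimesh:
    # Build an oriented point cloud directly from the trimesh vertices/normals.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(src.vertices))
    pcd.normals = o3d.utility.Vector3dVector(np.asarray(src.vertex_normals))
    # Screened Poisson reconstruction; densities could be used to trim low-support faces.
    recon, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=depth)
    return trimesh.Trimesh(np.asarray(recon.vertices), np.asarray(recon.triangles))
```

Exporting to a temporary point-cloud file first, as the diff does, keeps the code path close to the old MeshLab-based one while dropping the extra dependency.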
|
|
|
|
|
|
|
|
407 |
|
408 |
|
409 |
# Losses to smooth / regularize the mesh shape
|
|
|
414 |
# mesh normal consistency
|
415 |
losses["nc"]["value"] = mesh_normal_consistency(mesh)
|
416 |
# mesh laplacian smoothing
|
417 |
+
losses["lapla"]["value"] = mesh_laplacian_smoothing(mesh, method="uniform")
|
|
|
|
|
|
|
|
418 |
|
419 |
|
420 |
def read_smpl_constants(folder):
|
|
|
433 |
smpl_vertex_code = np.float32(np.copy(smpl_vtx_std))
|
434 |
"""Load smpl faces & tetrahedrons"""
|
435 |
smpl_faces = np.loadtxt(os.path.join(folder, "faces.txt"), dtype=np.int32) - 1
|
436 |
+
smpl_face_code = (
|
437 |
+
smpl_vertex_code[smpl_faces[:, 0]] + smpl_vertex_code[smpl_faces[:, 1]] +
|
438 |
+
smpl_vertex_code[smpl_faces[:, 2]]
|
439 |
+
) / 3.0
|
440 |
smpl_tetras = (np.loadtxt(os.path.join(folder, "tetrahedrons.txt"), dtype=np.int32) - 1)
|
441 |
|
442 |
return_dict = {
|
|
|
449 |
return return_dict
|
450 |
|
451 |
|
|
|
|
|
|
452 |
def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel=1):
|
453 |
"""get the visibility of vertices
|
454 |
|
|
|
487 |
|
488 |
for idx in range(N_body):
|
489 |
Num_faces = len(faces[idx])
|
490 |
+
vis_vertices_id = torch.unique(
|
491 |
+
faces[idx][torch.unique(pix_to_face[idx][pix_to_face[idx] != -1]) - Num_faces * idx, :]
|
492 |
+
)
|
493 |
vis_mask[idx, vis_vertices_id] = 1.0
|
494 |
|
495 |
# print("------------------------\n")
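
The visibility test above relies on pytorch3d's `rasterize_meshes`: every pixel stores the index of the face it sees in `pix_to_face`, and the union of those faces' vertices is marked visible. A condensed single-mesh sketch of the same idea (tensor names are illustrative; `verts` is assumed to already be in the normalized screen space used above):

```python
import torch
from pytorch3d.structures import Meshes
from pytorch3d.renderer.mesh import rasterize_meshes

def visible_vertex_mask(verts: torch.Tensor, faces: torch.Tensor, img_res: int = 512):
    # verts: (V, 3), faces: (F, 3)
    meshes = Meshes(verts=[verts], faces=[faces])
    pix_to_face, _, _, _ = rasterize_meshes(
        meshes, image_size=img_res, blur_radius=0.0, faces_per_pixel=1
    )
    # Face ids seen by at least one pixel (-1 means background).
    visible_faces = torch.unique(pix_to_face[pix_to_face != -1])
    vis_mask = torch.zeros(verts.shape[0])
    vis_mask[torch.unique(faces[visible_faces])] = 1.0
    return vis_mask
```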
|
|
|
543 |
"""
|
544 |
rot = calibrations[:, :3, :3]
|
545 |
trans = calibrations[:, :3, 3:4]
|
546 |
+
pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
|
547 |
if transforms is not None:
|
548 |
scale = transforms[:2, :2]
|
549 |
shift = transforms[:2, 2:3]
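
`torch.baddbmm(trans, rot, points)` is just the batched affine map: for each batch element it returns `trans + rot @ points`, i.e. it rotates/scales the `[3, N]` point block and adds the `[3, 1]` translation, broadcast over N. A tiny self-check:

```python
import torch

B, N = 2, 5
rot = torch.randn(B, 3, 3)
trans = torch.randn(B, 3, 1)
points = torch.randn(B, 3, N)

out = torch.baddbmm(trans, rot, points)    # [B, 3, N]
ref = torch.bmm(rot, points) + trans       # the same map, written out
assert torch.allclose(out, ref, atol=1e-6)
```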
|
|
|
643 |
return vert_norm
|
644 |
|
645 |
|
|
|
|
|
|
|
|
|
646 |
def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"):
|
647 |
|
648 |
font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf")
|
649 |
font = ImageFont.truetype(font_path, 30)
|
650 |
grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0), nrow=nrow, padding=0)
|
651 |
+
grid_img = Image.fromarray(
|
652 |
+
((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * 255.0).astype(np.uint8)
|
653 |
+
)
|
654 |
|
655 |
if False:
|
656 |
# add text
|
|
|
660 |
draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font)
|
661 |
|
662 |
if type == "smpl":
|
663 |
+
for col_id, col_txt in enumerate(
|
664 |
+
[
|
665 |
"image",
|
666 |
"smpl-norm(render)",
|
667 |
"cloth-norm(pred)",
|
668 |
"diff-norm",
|
669 |
"diff-mask",
|
670 |
+
]
|
671 |
+
):
|
672 |
draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
|
673 |
elif type == "cloth":
|
674 |
+
for col_id, col_txt in enumerate(
|
675 |
+
["image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"]
|
676 |
+
):
|
677 |
draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font)
|
678 |
for col_id, col_txt in enumerate(["0", "90", "180", "270"]):
|
679 |
draw.text(
|
|
|
695 |
device = verts.device
|
696 |
|
697 |
mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(), faces.detach().cpu().numpy())
|
698 |
+
largest_mesh = keep_largest(mesh_lst)
|
699 |
+
final_verts = torch.as_tensor(largest_mesh.vertices).float().to(device)
|
700 |
+
final_faces = torch.as_tensor(largest_mesh.faces).long().to(device)
|
|
|
|
|
|
|
701 |
|
702 |
return final_verts, final_faces
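
`keep_largest` is called here and in `poisson` above, but its definition falls outside these hunks. A plausible minimal version, assuming it simply keeps the connected component with the most vertices, is sketched below.

```python
import trimesh

def keep_largest(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
    # Split into connected components (not only watertight ones) and
    # return the component with the largest vertex count.
    components = mesh.split(only_watertight=False)
    if len(components) == 0:
        return mesh
    return max(components, key=lambda m: m.vertices.shape[0])
```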
|
703 |
|
lib/net/BasePIFuNet.py
CHANGED
@@ -21,11 +21,10 @@ from .geometry import index, orthogonal, perspective
|
|
21 |
|
22 |
|
23 |
class BasePIFuNet(pl.LightningModule):
|
24 |
-
|
25 |
def __init__(
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
):
|
30 |
"""
|
31 |
:param projection_mode:
|
|
|
21 |
|
22 |
|
23 |
class BasePIFuNet(pl.LightningModule):
|
|
|
24 |
def __init__(
|
25 |
+
self,
|
26 |
+
projection_mode="orthogonal",
|
27 |
+
error_term=nn.MSELoss(),
|
28 |
):
|
29 |
"""
|
30 |
:param projection_mode:
|
lib/net/Discriminator.py
CHANGED
@@ -9,17 +9,18 @@ from lib.torch_utils.ops.native_ops import FusedLeakyReLU, fused_leaky_relu, upf
|
|
9 |
|
10 |
|
11 |
class DiscriminatorHead(nn.Module):
|
12 |
-
|
13 |
def __init__(self, in_channel, disc_stddev=False):
|
14 |
super().__init__()
|
15 |
|
16 |
self.disc_stddev = disc_stddev
|
17 |
stddev_dim = 1 if disc_stddev else 0
|
18 |
|
19 |
-
self.conv_stddev = ConvLayer2d(
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
23 |
|
24 |
self.final_linear = nn.Sequential(
|
25 |
nn.Flatten(),
|
@@ -32,8 +33,8 @@ class DiscriminatorHead(nn.Module):
|
|
32 |
inv_perm = torch.argsort(perm)
|
33 |
|
34 |
batch, channel, height, width = x.shape
|
35 |
-
x = x[
|
36 |
-
|
37 |
|
38 |
group = min(batch, stddev_group)
|
39 |
stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
|
@@ -41,7 +42,7 @@ class DiscriminatorHead(nn.Module):
|
|
41 |
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
42 |
stddev = stddev.repeat(group, 1, height, width)
|
43 |
|
44 |
-
stddev = stddev[inv_perm]
|
45 |
x = x[inv_perm]
|
46 |
|
47 |
out = torch.cat([x, stddev], 1)
|
@@ -56,7 +57,6 @@ class DiscriminatorHead(nn.Module):
|
|
56 |
|
57 |
|
58 |
class ConvDecoder(nn.Module):
|
59 |
-
|
60 |
def __init__(self, in_channel, out_channel, in_res, out_res):
|
61 |
super().__init__()
|
62 |
|
@@ -68,20 +68,22 @@ class ConvDecoder(nn.Module):
|
|
68 |
for i in range(log_size_in, log_size_out):
|
69 |
out_ch = in_ch // 2
|
70 |
self.layers.append(
|
71 |
-
ConvLayer2d(
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
77 |
in_ch = out_ch
|
78 |
|
79 |
self.layers.append(
|
80 |
-
ConvLayer2d(
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
activate=False))
|
85 |
self.layers = nn.Sequential(*self.layers)
|
86 |
|
87 |
def forward(self, x):
|
@@ -89,7 +91,6 @@ class ConvDecoder(nn.Module):
|
|
89 |
|
90 |
|
91 |
class StyleDiscriminator(nn.Module):
|
92 |
-
|
93 |
def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
|
94 |
super().__init__()
|
95 |
|
@@ -104,7 +105,8 @@ class StyleDiscriminator(nn.Module):
|
|
104 |
for i in range(log_size_in, log_size_out, -1):
|
105 |
out_channels = int(min(in_channels * 2, ch_max))
|
106 |
self.layers.append(
|
107 |
-
ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True)
|
|
|
108 |
in_channels = out_channels
|
109 |
self.layers = nn.Sequential(*self.layers)
|
110 |
|
@@ -147,7 +149,6 @@ class Blur(nn.Module):
|
|
147 |
Upsample factor.
|
148 |
|
149 |
"""
|
150 |
-
|
151 |
def __init__(self, kernel, pad, upsample_factor=1):
|
152 |
super().__init__()
|
153 |
|
@@ -177,7 +178,6 @@ class Upsample(nn.Module):
|
|
177 |
Upsampling factor.
|
178 |
|
179 |
"""
|
180 |
-
|
181 |
def __init__(self, kernel=[1, 3, 3, 1], factor=2):
|
182 |
super().__init__()
|
183 |
|
@@ -208,7 +208,6 @@ class Downsample(nn.Module):
|
|
208 |
Downsampling factor.
|
209 |
|
210 |
"""
|
211 |
-
|
212 |
def __init__(self, kernel=[1, 3, 3, 1], factor=2):
|
213 |
super().__init__()
|
214 |
|
@@ -250,7 +249,6 @@ class EqualLinear(nn.Module):
|
|
250 |
Apply leakyReLU activation.
|
251 |
|
252 |
"""
|
253 |
-
|
254 |
def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
|
255 |
super().__init__()
|
256 |
|
@@ -300,7 +298,6 @@ class EqualConv2d(nn.Module):
|
|
300 |
Use bias term.
|
301 |
|
302 |
"""
|
303 |
-
|
304 |
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
|
305 |
super().__init__()
|
306 |
|
@@ -316,16 +313,20 @@ class EqualConv2d(nn.Module):
|
|
316 |
self.bias = None
|
317 |
|
318 |
def forward(self, input):
|
319 |
-
out = F.conv2d(
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
|
|
|
|
324 |
return out
|
325 |
|
326 |
def __repr__(self):
|
327 |
-
return (
|
328 |
-
|
|
|
|
|
329 |
|
330 |
|
331 |
class EqualConvTranspose2d(nn.Module):
|
@@ -353,15 +354,16 @@ class EqualConvTranspose2d(nn.Module):
|
|
353 |
Use bias term.
|
354 |
|
355 |
"""
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
|
|
365 |
super().__init__()
|
366 |
|
367 |
self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
|
@@ -388,12 +390,13 @@ class EqualConvTranspose2d(nn.Module):
|
|
388 |
return out
|
389 |
|
390 |
def __repr__(self):
|
391 |
-
return (
|
392 |
-
|
|
|
|
|
393 |
|
394 |
|
395 |
class ConvLayer2d(nn.Sequential):
|
396 |
-
|
397 |
def __init__(
|
398 |
self,
|
399 |
in_channel,
|
@@ -415,12 +418,15 @@ class ConvLayer2d(nn.Sequential):
|
|
415 |
pad1 = p // 2 + 1
|
416 |
|
417 |
layers.append(
|
418 |
-
EqualConvTranspose2d(
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
|
|
|
|
|
|
424 |
layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))
|
425 |
|
426 |
if downsample:
|
@@ -431,23 +437,29 @@ class ConvLayer2d(nn.Sequential):
|
|
431 |
|
432 |
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
433 |
layers.append(
|
434 |
-
EqualConv2d(
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
|
|
|
|
|
|
440 |
|
441 |
if (not downsample) and (not upsample):
|
442 |
padding = kernel_size // 2
|
443 |
|
444 |
layers.append(
|
445 |
-
EqualConv2d(
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
|
|
|
|
|
|
451 |
|
452 |
if activate:
|
453 |
layers.append(FusedLeakyReLU(out_channel, bias=bias))
|
@@ -472,7 +484,6 @@ class ConvResBlock2d(nn.Module):
|
|
472 |
Apply downsampling via strided convolution in the second conv.
|
473 |
|
474 |
"""
|
475 |
-
|
476 |
def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
|
477 |
super().__init__()
|
478 |
|
|
|
9 |
|
10 |
|
11 |
class DiscriminatorHead(nn.Module):
|
|
|
12 |
def __init__(self, in_channel, disc_stddev=False):
|
13 |
super().__init__()
|
14 |
|
15 |
self.disc_stddev = disc_stddev
|
16 |
stddev_dim = 1 if disc_stddev else 0
|
17 |
|
18 |
+
self.conv_stddev = ConvLayer2d(
|
19 |
+
in_channel=in_channel + stddev_dim,
|
20 |
+
out_channel=in_channel,
|
21 |
+
kernel_size=3,
|
22 |
+
activate=True
|
23 |
+
)
|
24 |
|
25 |
self.final_linear = nn.Sequential(
|
26 |
nn.Flatten(),
|
|
|
33 |
inv_perm = torch.argsort(perm)
|
34 |
|
35 |
batch, channel, height, width = x.shape
|
36 |
+
x = x[perm
|
37 |
+
] # shuffle inputs so that all views in a single trajectory don't get put together
|
38 |
|
39 |
group = min(batch, stddev_group)
|
40 |
stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
|
|
|
42 |
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
43 |
stddev = stddev.repeat(group, 1, height, width)
|
44 |
|
45 |
+
stddev = stddev[inv_perm] # reorder inputs
|
46 |
x = x[inv_perm]
|
47 |
|
48 |
out = torch.cat([x, stddev], 1)
|
|
|
57 |
|
58 |
|
59 |
class ConvDecoder(nn.Module):
|
|
|
60 |
def __init__(self, in_channel, out_channel, in_res, out_res):
|
61 |
super().__init__()
|
62 |
|
|
|
68 |
for i in range(log_size_in, log_size_out):
|
69 |
out_ch = in_ch // 2
|
70 |
self.layers.append(
|
71 |
+
ConvLayer2d(
|
72 |
+
in_channel=in_ch,
|
73 |
+
out_channel=out_ch,
|
74 |
+
kernel_size=3,
|
75 |
+
upsample=True,
|
76 |
+
bias=True,
|
77 |
+
activate=True
|
78 |
+
)
|
79 |
+
)
|
80 |
in_ch = out_ch
|
81 |
|
82 |
self.layers.append(
|
83 |
+
ConvLayer2d(
|
84 |
+
in_channel=in_ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False
|
85 |
+
)
|
86 |
+
)
|
|
|
87 |
self.layers = nn.Sequential(*self.layers)
|
88 |
|
89 |
def forward(self, x):
|
|
|
91 |
|
92 |
|
93 |
class StyleDiscriminator(nn.Module):
|
|
|
94 |
def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
|
95 |
super().__init__()
|
96 |
|
|
|
105 |
for i in range(log_size_in, log_size_out, -1):
|
106 |
out_channels = int(min(in_channels * 2, ch_max))
|
107 |
self.layers.append(
|
108 |
+
ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True)
|
109 |
+
)
|
110 |
in_channels = out_channels
|
111 |
self.layers = nn.Sequential(*self.layers)
|
112 |
|
|
|
149 |
Upsample factor.
|
150 |
|
151 |
"""
|
|
|
152 |
def __init__(self, kernel, pad, upsample_factor=1):
|
153 |
super().__init__()
|
154 |
|
|
|
178 |
Upsampling factor.
|
179 |
|
180 |
"""
|
|
|
181 |
def __init__(self, kernel=[1, 3, 3, 1], factor=2):
|
182 |
super().__init__()
|
183 |
|
|
|
208 |
Downsampling factor.
|
209 |
|
210 |
"""
|
|
|
211 |
def __init__(self, kernel=[1, 3, 3, 1], factor=2):
|
212 |
super().__init__()
|
213 |
|
|
|
249 |
Apply leakyReLU activation.
|
250 |
|
251 |
"""
|
|
|
252 |
def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
|
253 |
super().__init__()
|
254 |
|
|
|
298 |
Use bias term.
|
299 |
|
300 |
"""
|
|
|
301 |
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
|
302 |
super().__init__()
|
303 |
|
|
|
313 |
self.bias = None
|
314 |
|
315 |
def forward(self, input):
|
316 |
+
out = F.conv2d(
|
317 |
+
input,
|
318 |
+
self.weight * self.scale,
|
319 |
+
bias=self.bias,
|
320 |
+
stride=self.stride,
|
321 |
+
padding=self.padding
|
322 |
+
)
|
323 |
return out
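
Scaling the weight by `self.scale` at every forward pass (instead of scaling the initialization) is the StyleGAN-style equalized learning-rate trick. A minimal stand-alone version of the same idea, with the usual `1/sqrt(fan_in)` scale (the exact constants in this repo may differ):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class EqualizedConv2d(nn.Module):
    """Conv2d whose weights are N(0, 1) at init and rescaled at runtime."""
    def __init__(self, in_ch, out_ch, k, stride=1, padding=0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_ch, in_ch, k, k))
        self.bias = nn.Parameter(torch.zeros(out_ch))
        self.scale = 1.0 / math.sqrt(in_ch * k * k)   # per-layer fan-in normalization
        self.stride, self.padding = stride, padding

    def forward(self, x):
        return F.conv2d(x, self.weight * self.scale, self.bias, self.stride, self.padding)
```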
|
324 |
|
325 |
def __repr__(self):
|
326 |
+
return (
|
327 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
328 |
+
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
|
329 |
+
)
|
330 |
|
331 |
|
332 |
class EqualConvTranspose2d(nn.Module):
|
|
|
354 |
Use bias term.
|
355 |
|
356 |
"""
|
357 |
+
def __init__(
|
358 |
+
self,
|
359 |
+
in_channel,
|
360 |
+
out_channel,
|
361 |
+
kernel_size,
|
362 |
+
stride=1,
|
363 |
+
padding=0,
|
364 |
+
output_padding=0,
|
365 |
+
bias=True
|
366 |
+
):
|
367 |
super().__init__()
|
368 |
|
369 |
self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
|
|
|
390 |
return out
|
391 |
|
392 |
def __repr__(self):
|
393 |
+
return (
|
394 |
+
f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},'
|
395 |
+
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
396 |
+
)
|
397 |
|
398 |
|
399 |
class ConvLayer2d(nn.Sequential):
|
|
|
400 |
def __init__(
|
401 |
self,
|
402 |
in_channel,
|
|
|
418 |
pad1 = p // 2 + 1
|
419 |
|
420 |
layers.append(
|
421 |
+
EqualConvTranspose2d(
|
422 |
+
in_channel,
|
423 |
+
out_channel,
|
424 |
+
kernel_size,
|
425 |
+
padding=0,
|
426 |
+
stride=2,
|
427 |
+
bias=bias and not activate
|
428 |
+
)
|
429 |
+
)
|
430 |
layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))
|
431 |
|
432 |
if downsample:
|
|
|
437 |
|
438 |
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
439 |
layers.append(
|
440 |
+
EqualConv2d(
|
441 |
+
in_channel,
|
442 |
+
out_channel,
|
443 |
+
kernel_size,
|
444 |
+
padding=0,
|
445 |
+
stride=2,
|
446 |
+
bias=bias and not activate
|
447 |
+
)
|
448 |
+
)
|
449 |
|
450 |
if (not downsample) and (not upsample):
|
451 |
padding = kernel_size // 2
|
452 |
|
453 |
layers.append(
|
454 |
+
EqualConv2d(
|
455 |
+
in_channel,
|
456 |
+
out_channel,
|
457 |
+
kernel_size,
|
458 |
+
padding=padding,
|
459 |
+
stride=1,
|
460 |
+
bias=bias and not activate
|
461 |
+
)
|
462 |
+
)
|
463 |
|
464 |
if activate:
|
465 |
layers.append(FusedLeakyReLU(out_channel, bias=bias))
|
|
|
484 |
Apply downsampling via strided convolution in the second conv.
|
485 |
|
486 |
"""
|
|
|
487 |
def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
|
488 |
super().__init__()
|
489 |
|
lib/net/FBNet.py
CHANGED
@@ -51,17 +51,17 @@ def get_norm_layer(norm_type="instance"):
|
|
51 |
|
52 |
|
53 |
def define_G(
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
):
|
66 |
norm_layer = get_norm_layer(norm_type=norm)
|
67 |
if netG == "global":
|
@@ -97,17 +97,20 @@ def define_G(
|
|
97 |
return netG
|
98 |
|
99 |
|
100 |
-
def define_D(
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
108 |
norm_layer = get_norm_layer(norm_type=norm)
|
109 |
-
netD = MultiscaleDiscriminator(
|
110 |
-
|
|
|
111 |
if len(gpu_ids) > 0:
|
112 |
assert (torch.cuda.is_available())
|
113 |
netD.cuda(gpu_ids[0])
|
@@ -129,7 +132,6 @@ def print_network(net):
|
|
129 |
# Generator
|
130 |
##############################################################################
|
131 |
class LocalEnhancer(pl.LightningModule):
|
132 |
-
|
133 |
def __init__(
|
134 |
self,
|
135 |
input_nc,
|
@@ -155,8 +157,9 @@ class LocalEnhancer(pl.LightningModule):
|
|
155 |
n_blocks_global,
|
156 |
norm_layer,
|
157 |
).model
|
158 |
-
model_global = [
|
159 |
-
|
|
|
160 |
self.model = nn.Sequential(*model_global)
|
161 |
|
162 |
###### local enhancer layers #####
|
@@ -224,17 +227,16 @@ class LocalEnhancer(pl.LightningModule):
|
|
224 |
|
225 |
|
226 |
class GlobalGenerator(pl.LightningModule):
|
227 |
-
|
228 |
def __init__(
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
):
|
239 |
assert n_blocks >= 0
|
240 |
super(GlobalGenerator, self).__init__()
|
@@ -296,42 +298,49 @@ class GlobalGenerator(pl.LightningModule):
|
|
296 |
|
297 |
# Defines the PatchGAN discriminator with the specified arguments.
|
298 |
class NLayerDiscriminator(nn.Module):
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
|
|
307 |
super(NLayerDiscriminator, self).__init__()
|
308 |
self.getIntermFeat = getIntermFeat
|
309 |
self.n_layers = n_layers
|
310 |
|
311 |
kw = 4
|
312 |
padw = int(np.ceil((kw - 1.0) / 2))
|
313 |
-
sequence = [
|
314 |
-
|
315 |
-
|
316 |
-
|
|
|
|
|
317 |
|
318 |
nf = ndf
|
319 |
for n in range(1, n_layers):
|
320 |
nf_prev = nf
|
321 |
nf = min(nf * 2, 512)
|
322 |
-
sequence += [
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
|
|
|
|
327 |
|
328 |
nf_prev = nf
|
329 |
nf = min(nf * 2, 512)
|
330 |
-
sequence += [
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
|
|
|
|
335 |
|
336 |
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
337 |
|
@@ -359,27 +368,30 @@ class NLayerDiscriminator(nn.Module):
|
|
359 |
|
360 |
|
361 |
class MultiscaleDiscriminator(pl.LightningModule):
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
|
|
371 |
super(MultiscaleDiscriminator, self).__init__()
|
372 |
self.num_D = num_D
|
373 |
self.n_layers = n_layers
|
374 |
self.getIntermFeat = getIntermFeat
|
375 |
|
376 |
for i in range(num_D):
|
377 |
-
netD = NLayerDiscriminator(
|
378 |
-
|
|
|
379 |
if getIntermFeat:
|
380 |
for j in range(n_layers + 2):
|
381 |
-
setattr(
|
382 |
-
|
|
|
383 |
else:
|
384 |
setattr(self, 'layer' + str(i), netD.model)
|
385 |
|
@@ -414,11 +426,11 @@ class MultiscaleDiscriminator(pl.LightningModule):
|
|
414 |
|
415 |
# Define a resnet block
|
416 |
class ResnetBlock(pl.LightningModule):
|
417 |
-
|
418 |
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
|
419 |
super(ResnetBlock, self).__init__()
|
420 |
-
self.conv_block = self.build_conv_block(
|
421 |
-
|
|
|
422 |
|
423 |
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
|
424 |
conv_block = []
|
@@ -459,7 +471,6 @@ class ResnetBlock(pl.LightningModule):
|
|
459 |
|
460 |
|
461 |
class Encoder(pl.LightningModule):
|
462 |
-
|
463 |
def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
|
464 |
super(Encoder, self).__init__()
|
465 |
self.output_nc = output_nc
|
@@ -510,18 +521,17 @@ class Encoder(pl.LightningModule):
|
|
510 |
inst_list = np.unique(inst.cpu().numpy().astype(int))
|
511 |
for i in inst_list:
|
512 |
for b in range(input.size()[0]):
|
513 |
-
indices = (inst[b:b + 1] == int(i)).nonzero()
|
514 |
for j in range(self.output_nc):
|
515 |
output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
|
516 |
-
indices[:, 3],]
|
517 |
mean_feat = torch.mean(output_ins).expand_as(output_ins)
|
518 |
outputs_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
|
519 |
-
indices[:, 3],] = mean_feat
|
520 |
return outputs_mean
|
521 |
|
522 |
|
523 |
class Vgg19(nn.Module):
|
524 |
-
|
525 |
def __init__(self, requires_grad=False):
|
526 |
super(Vgg19, self).__init__()
|
527 |
vgg_pretrained_features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
|
@@ -555,7 +565,6 @@ class Vgg19(nn.Module):
|
|
555 |
|
556 |
|
557 |
class VGG19FeatLayer(nn.Module):
|
558 |
-
|
559 |
def __init__(self):
|
560 |
super(VGG19FeatLayer, self).__init__()
|
561 |
self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
|
@@ -593,7 +602,6 @@ class VGG19FeatLayer(nn.Module):
|
|
593 |
|
594 |
|
595 |
class VGGLoss(pl.LightningModule):
|
596 |
-
|
597 |
def __init__(self):
|
598 |
super(VGGLoss, self).__init__()
|
599 |
self.vgg = Vgg19().eval()
|
@@ -609,11 +617,7 @@ class VGGLoss(pl.LightningModule):
|
|
609 |
|
610 |
|
611 |
class GANLoss(pl.LightningModule):
|
612 |
-
|
613 |
-
def __init__(self,
|
614 |
-
use_lsgan=True,
|
615 |
-
target_real_label=1.0,
|
616 |
-
target_fake_label=0.0):
|
617 |
super(GANLoss, self).__init__()
|
618 |
self.real_label = target_real_label
|
619 |
self.fake_label = target_fake_label
|
@@ -628,16 +632,18 @@ class GANLoss(pl.LightningModule):
|
|
628 |
def get_target_tensor(self, input, target_is_real):
|
629 |
target_tensor = None
|
630 |
if target_is_real:
|
631 |
-
create_label = (
|
632 |
-
|
|
|
633 |
if create_label:
|
634 |
real_tensor = self.tensor(input.size()).fill_(self.real_label)
|
635 |
self.real_label_var = real_tensor
|
636 |
self.real_label_var.requires_grad = False
|
637 |
target_tensor = self.real_label_var
|
638 |
else:
|
639 |
-
create_label = (
|
640 |
-
|
|
|
641 |
if create_label:
|
642 |
fake_tensor = self.tensor(input.size()).fill_(self.fake_label)
|
643 |
self.fake_label_var = fake_tensor
|
@@ -659,7 +665,6 @@ class GANLoss(pl.LightningModule):
|
|
659 |
|
660 |
|
661 |
class IDMRFLoss(pl.LightningModule):
|
662 |
-
|
663 |
def __init__(self, featlayer=VGG19FeatLayer):
|
664 |
super(IDMRFLoss, self).__init__()
|
665 |
self.featlayer = featlayer()
|
@@ -678,7 +683,8 @@ class IDMRFLoss(pl.LightningModule):
|
|
678 |
patch_size = 1
|
679 |
patch_stride = 1
|
680 |
patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(
|
681 |
-
3, patch_size, patch_stride
|
|
|
682 |
self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
|
683 |
dims = self.patches_OIHW.size()
|
684 |
self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
|
@@ -743,7 +749,8 @@ class IDMRFLoss(pl.LightningModule):
|
|
743 |
self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer])
|
744 |
for layer in self.feat_content_layers
|
745 |
]
|
746 |
-
self.content_loss = functools.reduce(
|
747 |
-
|
|
|
748 |
|
749 |
return self.style_loss + self.content_loss
|
|
|
51 |
|
52 |
|
53 |
def define_G(
|
54 |
+
input_nc,
|
55 |
+
output_nc,
|
56 |
+
ngf,
|
57 |
+
netG,
|
58 |
+
n_downsample_global=3,
|
59 |
+
n_blocks_global=9,
|
60 |
+
n_local_enhancers=1,
|
61 |
+
n_blocks_local=3,
|
62 |
+
norm="instance",
|
63 |
+
gpu_ids=[],
|
64 |
+
last_op=nn.Tanh(),
|
65 |
):
|
66 |
norm_layer = get_norm_layer(norm_type=norm)
|
67 |
if netG == "global":
|
|
|
97 |
return netG
|
98 |
|
99 |
|
100 |
+
def define_D(
|
101 |
+
input_nc,
|
102 |
+
ndf,
|
103 |
+
n_layers_D,
|
104 |
+
norm='instance',
|
105 |
+
use_sigmoid=False,
|
106 |
+
num_D=1,
|
107 |
+
getIntermFeat=False,
|
108 |
+
gpu_ids=[]
|
109 |
+
):
|
110 |
norm_layer = get_norm_layer(norm_type=norm)
|
111 |
+
netD = MultiscaleDiscriminator(
|
112 |
+
input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat
|
113 |
+
)
|
114 |
if len(gpu_ids) > 0:
|
115 |
assert (torch.cuda.is_available())
|
116 |
netD.cuda(gpu_ids[0])
|
|
|
132 |
# Generator
|
133 |
##############################################################################
|
134 |
class LocalEnhancer(pl.LightningModule):
|
|
|
135 |
def __init__(
|
136 |
self,
|
137 |
input_nc,
|
|
|
157 |
n_blocks_global,
|
158 |
norm_layer,
|
159 |
).model
|
160 |
+
model_global = [
|
161 |
+
model_global[i] for i in range(len(model_global) - 3)
|
162 |
+
] # get rid of final convolution layers
|
163 |
self.model = nn.Sequential(*model_global)
|
164 |
|
165 |
###### local enhancer layers #####
|
|
|
227 |
|
228 |
|
229 |
class GlobalGenerator(pl.LightningModule):
|
|
|
230 |
def __init__(
|
231 |
+
self,
|
232 |
+
input_nc,
|
233 |
+
output_nc,
|
234 |
+
ngf=64,
|
235 |
+
n_downsampling=3,
|
236 |
+
n_blocks=9,
|
237 |
+
norm_layer=nn.BatchNorm2d,
|
238 |
+
padding_type="reflect",
|
239 |
+
last_op=nn.Tanh(),
|
240 |
):
|
241 |
assert n_blocks >= 0
|
242 |
super(GlobalGenerator, self).__init__()
|
|
|
298 |
|
299 |
# Defines the PatchGAN discriminator with the specified arguments.
|
300 |
class NLayerDiscriminator(nn.Module):
|
301 |
+
def __init__(
|
302 |
+
self,
|
303 |
+
input_nc,
|
304 |
+
ndf=64,
|
305 |
+
n_layers=3,
|
306 |
+
norm_layer=nn.BatchNorm2d,
|
307 |
+
use_sigmoid=False,
|
308 |
+
getIntermFeat=False
|
309 |
+
):
|
310 |
super(NLayerDiscriminator, self).__init__()
|
311 |
self.getIntermFeat = getIntermFeat
|
312 |
self.n_layers = n_layers
|
313 |
|
314 |
kw = 4
|
315 |
padw = int(np.ceil((kw - 1.0) / 2))
|
316 |
+
sequence = [
|
317 |
+
[
|
318 |
+
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
319 |
+
nn.LeakyReLU(0.2, True)
|
320 |
+
]
|
321 |
+
]
|
322 |
|
323 |
nf = ndf
|
324 |
for n in range(1, n_layers):
|
325 |
nf_prev = nf
|
326 |
nf = min(nf * 2, 512)
|
327 |
+
sequence += [
|
328 |
+
[
|
329 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
|
330 |
+
norm_layer(nf),
|
331 |
+
nn.LeakyReLU(0.2, True)
|
332 |
+
]
|
333 |
+
]
|
334 |
|
335 |
nf_prev = nf
|
336 |
nf = min(nf * 2, 512)
|
337 |
+
sequence += [
|
338 |
+
[
|
339 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
|
340 |
+
norm_layer(nf),
|
341 |
+
nn.LeakyReLU(0.2, True)
|
342 |
+
]
|
343 |
+
]
|
344 |
|
345 |
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
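The nested lists above are later flattened into either per-scale submodules (`getIntermFeat`) or one `nn.Sequential`; the stack itself is a standard PatchGAN whose output is a grid of real/fake scores, one per receptive-field patch. A stand-alone rebuild of the same recipe, just to show the shape of the output (input channels and resolution are made up):

```python
import numpy as np
import torch
import torch.nn as nn

kw, ndf, n_layers = 4, 64, 3
padw = int(np.ceil((kw - 1.0) / 2))
layers = [nn.Conv2d(3, ndf, kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf = ndf
for _ in range(1, n_layers):
    nf_prev, nf = nf, min(nf * 2, 512)
    layers += [nn.Conv2d(nf_prev, nf, kw, stride=2, padding=padw), nn.BatchNorm2d(nf), nn.LeakyReLU(0.2, True)]
nf_prev, nf = nf, min(nf * 2, 512)
layers += [nn.Conv2d(nf_prev, nf, kw, stride=1, padding=padw), nn.BatchNorm2d(nf), nn.LeakyReLU(0.2, True)]
layers += [nn.Conv2d(nf, 1, kw, stride=1, padding=padw)]
patch_d = nn.Sequential(*layers)

scores = patch_d(torch.randn(1, 3, 512, 512))   # [1, 1, H', W'] patch map, not a single scalar
```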
|
346 |
|
|
|
368 |
|
369 |
|
370 |
class MultiscaleDiscriminator(pl.LightningModule):
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
input_nc,
|
374 |
+
ndf=64,
|
375 |
+
n_layers=3,
|
376 |
+
norm_layer=nn.BatchNorm2d,
|
377 |
+
use_sigmoid=False,
|
378 |
+
num_D=3,
|
379 |
+
getIntermFeat=False
|
380 |
+
):
|
381 |
super(MultiscaleDiscriminator, self).__init__()
|
382 |
self.num_D = num_D
|
383 |
self.n_layers = n_layers
|
384 |
self.getIntermFeat = getIntermFeat
|
385 |
|
386 |
for i in range(num_D):
|
387 |
+
netD = NLayerDiscriminator(
|
388 |
+
input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
|
389 |
+
)
|
390 |
if getIntermFeat:
|
391 |
for j in range(n_layers + 2):
|
392 |
+
setattr(
|
393 |
+
self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))
|
394 |
+
)
|
395 |
else:
|
396 |
setattr(self, 'layer' + str(i), netD.model)
|
397 |
|
|
|
426 |
|
427 |
# Define a resnet block
|
428 |
class ResnetBlock(pl.LightningModule):
|
|
|
429 |
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
|
430 |
super(ResnetBlock, self).__init__()
|
431 |
+
self.conv_block = self.build_conv_block(
|
432 |
+
dim, padding_type, norm_layer, activation, use_dropout
|
433 |
+
)
|
434 |
|
435 |
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
|
436 |
conv_block = []
|
|
|
471 |
|
472 |
|
473 |
class Encoder(pl.LightningModule):
|
|
|
474 |
def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
|
475 |
super(Encoder, self).__init__()
|
476 |
self.output_nc = output_nc
|
|
|
521 |
inst_list = np.unique(inst.cpu().numpy().astype(int))
|
522 |
for i in inst_list:
|
523 |
for b in range(input.size()[0]):
|
524 |
+
indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4
|
525 |
for j in range(self.output_nc):
|
526 |
output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
|
527 |
+
indices[:, 3], ]
|
528 |
mean_feat = torch.mean(output_ins).expand_as(output_ins)
|
529 |
outputs_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2],
|
530 |
+
indices[:, 3], ] = mean_feat
|
531 |
return outputs_mean
|
532 |
|
533 |
|
534 |
class Vgg19(nn.Module):
|
|
|
535 |
def __init__(self, requires_grad=False):
|
536 |
super(Vgg19, self).__init__()
|
537 |
vgg_pretrained_features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
|
|
|
565 |
|
566 |
|
567 |
class VGG19FeatLayer(nn.Module):
|
|
|
568 |
def __init__(self):
|
569 |
super(VGG19FeatLayer, self).__init__()
|
570 |
self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
|
|
|
602 |
|
603 |
|
604 |
class VGGLoss(pl.LightningModule):
|
|
|
605 |
def __init__(self):
|
606 |
super(VGGLoss, self).__init__()
|
607 |
self.vgg = Vgg19().eval()
|
|
|
617 |
|
618 |
|
619 |
class GANLoss(pl.LightningModule):
|
620 |
+
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
|
|
|
|
|
|
|
|
|
621 |
super(GANLoss, self).__init__()
|
622 |
self.real_label = target_real_label
|
623 |
self.fake_label = target_fake_label
|
|
|
632 |
def get_target_tensor(self, input, target_is_real):
|
633 |
target_tensor = None
|
634 |
if target_is_real:
|
635 |
+
create_label = (
|
636 |
+
(self.real_label_var is None) or (self.real_label_var.numel() != input.numel())
|
637 |
+
)
|
638 |
if create_label:
|
639 |
real_tensor = self.tensor(input.size()).fill_(self.real_label)
|
640 |
self.real_label_var = real_tensor
|
641 |
self.real_label_var.requires_grad = False
|
642 |
target_tensor = self.real_label_var
|
643 |
else:
|
644 |
+
create_label = (
|
645 |
+
(self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())
|
646 |
+
)
|
647 |
if create_label:
|
648 |
fake_tensor = self.tensor(input.size()).fill_(self.fake_label)
|
649 |
self.fake_label_var = fake_tensor
|
|
|
665 |
|
666 |
|
667 |
class IDMRFLoss(pl.LightningModule):
|
|
|
668 |
def __init__(self, featlayer=VGG19FeatLayer):
|
669 |
super(IDMRFLoss, self).__init__()
|
670 |
self.featlayer = featlayer()
|
|
|
683 |
patch_size = 1
|
684 |
patch_stride = 1
|
685 |
patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(
|
686 |
+
3, patch_size, patch_stride
|
687 |
+
)
|
688 |
self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
|
689 |
dims = self.patches_OIHW.size()
|
690 |
self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
|
|
|
749 |
self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer])
|
750 |
for layer in self.feat_content_layers
|
751 |
]
|
752 |
+
self.content_loss = functools.reduce(
|
753 |
+
lambda x, y: x + y, content_loss_list
|
754 |
+
) * self.lambda_content
|
755 |
|
756 |
return self.style_loss + self.content_loss
|
lib/net/GANLoss.py
CHANGED
@@ -32,13 +32,12 @@ def logistic_loss(fake_pred, real_pred, mode):
|
|
32 |
|
33 |
|
34 |
def r1_loss(real_pred, real_img):
|
35 |
-
(grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
|
36 |
grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
|
37 |
return grad_penalty
|
38 |
|
39 |
|
40 |
class GANLoss(nn.Module):
|
41 |
-
|
42 |
def __init__(
|
43 |
self,
|
44 |
opt,
|
@@ -64,7 +63,7 @@ class GANLoss(nn.Module):
|
|
64 |
logits_fake = self.discriminator(disc_in_fake)
|
65 |
|
66 |
disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')
|
67 |
-
|
68 |
log = {
|
69 |
"disc_loss": disc_loss.detach(),
|
70 |
"logits_real": logits_real.mean().detach(),
|
|
|
32 |
|
33 |
|
34 |
def r1_loss(real_pred, real_img):
|
35 |
+
(grad_real, ) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
|
36 |
grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
|
37 |
return grad_penalty
|
38 |
|
39 |
|
40 |
class GANLoss(nn.Module):
|
|
|
41 |
def __init__(
|
42 |
self,
|
43 |
opt,
|
|
|
63 |
logits_fake = self.discriminator(disc_in_fake)
|
64 |
|
65 |
disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')
|
66 |
+
|
67 |
log = {
|
68 |
"disc_loss": disc_loss.detach(),
|
69 |
"logits_real": logits_real.mean().detach(),
|
lib/net/IFGeoNet.py
CHANGED
@@ -8,20 +8,17 @@ from lib.dataset.mesh_util import read_smpl_constants, SMPLX
|
|
8 |
|
9 |
|
10 |
class SelfAttention(torch.nn.Module):
|
11 |
-
|
12 |
def __init__(self, in_channels, out_channels):
|
13 |
super().__init__()
|
14 |
-
self.conv = nn.Conv3d(in_channels,
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
padding_mode='replicate',
|
24 |
-
bias=False)
|
25 |
with torch.no_grad():
|
26 |
self.attention.weight.copy_(torch.zeros_like(self.attention.weight))
|
27 |
|
@@ -32,38 +29,45 @@ class SelfAttention(torch.nn.Module):
|
|
32 |
|
33 |
|
34 |
class IFGeoNet(nn.Module):
|
35 |
-
|
36 |
def __init__(self, cfg, hidden_dim=256):
|
37 |
super(IFGeoNet, self).__init__()
|
38 |
|
39 |
-
self.conv_in_partial = nn.Conv3d(
|
40 |
-
|
|
|
41 |
|
42 |
-
self.conv_in_smpl = nn.Conv3d(
|
43 |
-
|
|
|
44 |
|
45 |
self.SA = SelfAttention(4, 4)
|
46 |
-
self.conv_0_fusion = nn.Conv3d(
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
self.
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
self.
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
|
69 |
self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)
|
@@ -97,21 +101,21 @@ class IFGeoNet(nn.Module):
|
|
97 |
smooth_kernel_size=7,
|
98 |
batch_size=cfg.batch_size,
|
99 |
)
|
100 |
-
|
101 |
self.l1_loss = nn.SmoothL1Loss()
|
102 |
|
103 |
def forward(self, batch):
|
104 |
-
|
105 |
if "body_voxels" in batch.keys():
|
106 |
x_smpl = batch["body_voxels"]
|
107 |
else:
|
108 |
with torch.no_grad():
|
109 |
self.voxelization.update_param(batch["voxel_faces"])
|
110 |
-
x_smpl = self.voxelization(batch["voxel_verts"])[:, 0]
|
111 |
-
|
112 |
p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
|
113 |
-
batch["calib"]).permute(0, 2, 1)
|
114 |
-
x = batch["depth_voxels"]
|
115 |
|
116 |
x = x.unsqueeze(1)
|
117 |
x_smpl = x_smpl.unsqueeze(1)
|
@@ -119,63 +123,67 @@ class IFGeoNet(nn.Module):
|
|
119 |
p = p.unsqueeze(1).unsqueeze(1)
|
120 |
|
121 |
# partial inputs feature extraction
|
122 |
-
feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners
|
123 |
net_partial = self.actvn(self.conv_in_partial(x))
|
124 |
net_partial = self.partial_conv_in_bn(net_partial)
|
125 |
-
net_partial = self.maxpool(net_partial)
|
126 |
|
127 |
# smpl inputs feature extraction
|
128 |
# feature_0_smpl = F.grid_sample(x_smpl, p, padding_mode='border', align_corners = True)
|
129 |
net_smpl = self.actvn(self.conv_in_smpl(x_smpl))
|
130 |
net_smpl = self.smpl_conv_in_bn(net_smpl)
|
131 |
-
net_smpl = self.maxpool(net_smpl)
|
132 |
net_smpl = self.SA(net_smpl)
|
133 |
-
|
134 |
# Feature fusion
|
135 |
net = self.actvn(self.conv_0_fusion(torch.concat([net_partial, net_smpl], dim=1)))
|
136 |
net = self.actvn(self.conv_0_1_fusion(net))
|
137 |
net = self.conv0_1_bn_fusion(net)
|
138 |
-
feature_1_fused = F.grid_sample(net, p, padding_mode='border', align_corners
|
139 |
# net = self.maxpool(net) # out 64
|
140 |
|
141 |
net = self.actvn(self.conv_0(net))
|
142 |
net = self.actvn(self.conv_0_1(net))
|
143 |
net = self.conv0_1_bn(net)
|
144 |
-
feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners
|
145 |
-
net = self.maxpool(net)
|
146 |
|
147 |
net = self.actvn(self.conv_1(net))
|
148 |
net = self.actvn(self.conv_1_1(net))
|
149 |
net = self.conv1_1_bn(net)
|
150 |
-
feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners
|
151 |
-
net = self.maxpool(net)
|
152 |
|
153 |
net = self.actvn(self.conv_2(net))
|
154 |
net = self.actvn(self.conv_2_1(net))
|
155 |
net = self.conv2_1_bn(net)
|
156 |
-
feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners
|
157 |
-
net = self.maxpool(net)
|
158 |
|
159 |
net = self.actvn(self.conv_3(net))
|
160 |
net = self.actvn(self.conv_3_1(net))
|
161 |
net = self.conv3_1_bn(net)
|
162 |
-
feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners
|
163 |
-
net = self.maxpool(net)
|
164 |
|
165 |
net = self.actvn(self.conv_4(net))
|
166 |
net = self.actvn(self.conv_4_1(net))
|
167 |
net = self.conv4_1_bn(net)
|
168 |
-
feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners
|
169 |
|
170 |
# here every channel corresponse to one feature.
|
171 |
|
172 |
-
features = torch.cat(
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
175 |
shape = features.shape
|
176 |
features = torch.reshape(
|
177 |
-
features,
|
178 |
-
|
179 |
# (B, featue_size, samples_num)
|
180 |
features = torch.cat((features, p_features), dim=1)
|
181 |
|
@@ -183,7 +191,7 @@ class IFGeoNet(nn.Module):
|
|
183 |
net = self.actvn(self.fc_1(net))
|
184 |
net = self.actvn(self.fc_2(net))
|
185 |
net = self.fc_out(net).squeeze(1)
|
186 |
-
|
187 |
return net
|
188 |
|
189 |
def compute_loss(self, prds, tgts):
|
|
|
8 |
|
9 |
|
10 |
class SelfAttention(torch.nn.Module):
|
|
|
11 |
def __init__(self, in_channels, out_channels):
|
12 |
super().__init__()
|
13 |
+
self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate')
|
14 |
+
self.attention = nn.Conv3d(
|
15 |
+
in_channels,
|
16 |
+
out_channels,
|
17 |
+
kernel_size=3,
|
18 |
+
padding=1,
|
19 |
+
padding_mode='replicate',
|
20 |
+
bias=False
|
21 |
+
)
|
|
|
|
|
22 |
with torch.no_grad():
|
23 |
self.attention.weight.copy_(torch.zeros_like(self.attention.weight))
|
24 |
|
|
|
29 |
|
30 |
|
31 |
class IFGeoNet(nn.Module):
|
|
|
32 |
def __init__(self, cfg, hidden_dim=256):
|
33 |
super(IFGeoNet, self).__init__()
|
34 |
|
35 |
+
self.conv_in_partial = nn.Conv3d(
|
36 |
+
1, 16, 3, padding=1, padding_mode='replicate'
|
37 |
+
) # out: 256 ->m.p. 128
|
38 |
|
39 |
+
self.conv_in_smpl = nn.Conv3d(
|
40 |
+
1, 4, 3, padding=1, padding_mode='replicate'
|
41 |
+
) # out: 256 ->m.p. 128
|
42 |
|
43 |
self.SA = SelfAttention(4, 4)
|
44 |
+
self.conv_0_fusion = nn.Conv3d(
|
45 |
+
16 + 4, 32, 3, padding=1, padding_mode='replicate'
|
46 |
+
) # out: 128
|
47 |
+
self.conv_0_1_fusion = nn.Conv3d(
|
48 |
+
32, 32, 3, padding=1, padding_mode='replicate'
|
49 |
+
) # out: 128 ->m.p. 64
|
50 |
+
|
51 |
+
self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128
|
52 |
+
self.conv_0_1 = nn.Conv3d(
|
53 |
+
32, 32, 3, padding=1, padding_mode='replicate'
|
54 |
+
) # out: 128 ->m.p. 64
|
55 |
+
|
56 |
+
self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64
|
57 |
+
self.conv_1_1 = nn.Conv3d(
|
58 |
+
64, 64, 3, padding=1, padding_mode='replicate'
|
59 |
+
) # out: 64 -> mp 32
|
60 |
+
|
61 |
+
self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32
|
62 |
+
self.conv_2_1 = nn.Conv3d(
|
63 |
+
128, 128, 3, padding=1, padding_mode='replicate'
|
64 |
+
) # out: 32 -> mp 16
|
65 |
+
self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16
|
66 |
+
self.conv_3_1 = nn.Conv3d(
|
67 |
+
128, 128, 3, padding=1, padding_mode='replicate'
|
68 |
+
) # out: 16 -> mp 8
|
69 |
+
self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
|
70 |
+
self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8
|
71 |
|
72 |
feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
|
73 |
self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)
|
|
|
101 |
smooth_kernel_size=7,
|
102 |
batch_size=cfg.batch_size,
|
103 |
)
|
104 |
+
|
105 |
self.l1_loss = nn.SmoothL1Loss()
|
106 |
|
107 |
def forward(self, batch):
|
108 |
+
|
109 |
if "body_voxels" in batch.keys():
|
110 |
x_smpl = batch["body_voxels"]
|
111 |
else:
|
112 |
with torch.no_grad():
|
113 |
self.voxelization.update_param(batch["voxel_faces"])
|
114 |
+
x_smpl = self.voxelization(batch["voxel_verts"])[:, 0] #[B, 128, 128, 128]
|
115 |
+
|
116 |
p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
|
117 |
+
batch["calib"]).permute(0, 2, 1) #[2, 60000, 3]
|
118 |
+
x = batch["depth_voxels"] #[B, 128, 128, 128]
|
119 |
|
120 |
x = x.unsqueeze(1)
|
121 |
x_smpl = x_smpl.unsqueeze(1)
|
|
|
123 |
p = p.unsqueeze(1).unsqueeze(1)
|
124 |
|
125 |
# partial inputs feature extraction
|
126 |
+
feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True)
|
127 |
net_partial = self.actvn(self.conv_in_partial(x))
|
128 |
net_partial = self.partial_conv_in_bn(net_partial)
|
129 |
+
net_partial = self.maxpool(net_partial) # out 64
|
130 |
|
131 |
# smpl inputs feature extraction
|
132 |
# feature_0_smpl = F.grid_sample(x_smpl, p, padding_mode='border', align_corners = True)
|
133 |
net_smpl = self.actvn(self.conv_in_smpl(x_smpl))
|
134 |
net_smpl = self.smpl_conv_in_bn(net_smpl)
|
135 |
+
net_smpl = self.maxpool(net_smpl) # out 64
|
136 |
net_smpl = self.SA(net_smpl)
|
137 |
+
|
138 |
# Feature fusion
|
139 |
net = self.actvn(self.conv_0_fusion(torch.concat([net_partial, net_smpl], dim=1)))
|
140 |
net = self.actvn(self.conv_0_1_fusion(net))
|
141 |
net = self.conv0_1_bn_fusion(net)
|
142 |
+
feature_1_fused = F.grid_sample(net, p, padding_mode='border', align_corners=True)
|
143 |
# net = self.maxpool(net) # out 64
|
144 |
|
145 |
net = self.actvn(self.conv_0(net))
|
146 |
net = self.actvn(self.conv_0_1(net))
|
147 |
net = self.conv0_1_bn(net)
|
148 |
+
feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
|
149 |
+
net = self.maxpool(net) # out 32
|
150 |
|
151 |
net = self.actvn(self.conv_1(net))
|
152 |
net = self.actvn(self.conv_1_1(net))
|
153 |
net = self.conv1_1_bn(net)
|
154 |
+
feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
|
155 |
+
net = self.maxpool(net) # out 16
|
156 |
|
157 |
net = self.actvn(self.conv_2(net))
|
158 |
net = self.actvn(self.conv_2_1(net))
|
159 |
net = self.conv2_1_bn(net)
|
160 |
+
feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
|
161 |
+
net = self.maxpool(net) # out 8
|
162 |
|
163 |
net = self.actvn(self.conv_3(net))
|
164 |
net = self.actvn(self.conv_3_1(net))
|
165 |
net = self.conv3_1_bn(net)
|
166 |
+
feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
|
167 |
+
net = self.maxpool(net) # out 4
|
168 |
|
169 |
net = self.actvn(self.conv_4(net))
|
170 |
net = self.actvn(self.conv_4_1(net))
|
171 |
net = self.conv4_1_bn(net)
|
172 |
+
feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2
|
173 |
|
174 |
# here every channel corresponds to one feature.
|
175 |
|
176 |
+
features = torch.cat(
|
177 |
+
(
|
178 |
+
feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
|
179 |
+
feature_6
|
180 |
+
),
|
181 |
+
dim=1
|
182 |
+
) # (B, features, 1,7,sample_num)
|
183 |
shape = features.shape
|
184 |
features = torch.reshape(
|
185 |
+
features, (shape[0], shape[1] * shape[3], shape[4])
|
186 |
+
) # (B, featues_per_sample, samples_num)
|
187 |
# (B, featue_size, samples_num)
|
188 |
features = torch.cat((features, p_features), dim=1)
|
189 |
|
|
|
191 |
net = self.actvn(self.fc_1(net))
|
192 |
net = self.actvn(self.fc_2(net))
|
193 |
net = self.fc_out(net).squeeze(1)
|
194 |
+
|
195 |
return net
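
All of the `F.grid_sample` calls above follow one pattern: query a `[B, C, D, H, W]` feature volume at continuous point locations, producing per-point features that are then concatenated across scales. The shape bookkeeping in isolation (names and sizes are illustrative):

```python
import torch
import torch.nn.functional as F

B, C, R, N = 2, 32, 128, 6000
volume = torch.randn(B, C, R, R, R)        # voxel features
points = torch.rand(B, N, 3) * 2 - 1       # query points in [-1, 1]^3

# grid_sample expects a 5D grid [B, D_out, H_out, W_out, 3]; use D_out = H_out = 1.
grid = points.unsqueeze(1).unsqueeze(1)     # [B, 1, 1, N, 3]
feat = F.grid_sample(volume, grid, padding_mode='border', align_corners=True)
feat = feat.reshape(B, C, N)                # one C-dim feature per query point
```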
|
196 |
|
197 |
def compute_loss(self, prds, tgts):
|
lib/net/IFGeoNet_nobody.py
CHANGED
@@ -8,16 +8,17 @@ from lib.dataset.mesh_util import read_smpl_constants, SMPLX

class SelfAttention(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate')
        self.attention = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            padding_mode='replicate',
            bias=False
        )
        with torch.no_grad():
            self.attention.weight.copy_(torch.zeros_like(self.attention.weight))

@@ -28,34 +29,39 @@ class SelfAttention(torch.nn.Module):

class IFGeoNet(nn.Module):
    def __init__(self, cfg, hidden_dim=256):
        super(IFGeoNet, self).__init__()

        self.conv_in_partial = nn.Conv3d(
            1, 16, 3, padding=1, padding_mode='replicate'
        )    # out: 256 ->m.p. 128

        self.SA = SelfAttention(4, 4)
        self.conv_0_fusion = nn.Conv3d(16, 32, 3, padding=1, padding_mode='replicate')    # out: 128
        self.conv_0_1_fusion = nn.Conv3d(
            32, 32, 3, padding=1, padding_mode='replicate'
        )    # out: 128 ->m.p. 64

        self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate')    # out: 128
        self.conv_0_1 = nn.Conv3d(
            32, 32, 3, padding=1, padding_mode='replicate'
        )    # out: 128 ->m.p. 64

        self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate')    # out: 64
        self.conv_1_1 = nn.Conv3d(
            64, 64, 3, padding=1, padding_mode='replicate'
        )    # out: 64 -> mp 32

        self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate')    # out: 32
        self.conv_2_1 = nn.Conv3d(
            128, 128, 3, padding=1, padding_mode='replicate'
        )    # out: 32 -> mp 16
        self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate')    # out: 16
        self.conv_3_1 = nn.Conv3d(
            128, 128, 3, padding=1, padding_mode='replicate'
        )    # out: 16 -> mp 8
        self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate')    # out: 8
        self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate')    # out: 8

        feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3
        self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1)

@@ -95,8 +101,8 @@ class IFGeoNet(nn.Module):

    def forward(self, batch):

        p = orthogonal(batch["samples_geo"].permute(0, 2, 1),
                       batch["calib"]).permute(0, 2, 1)    #[2, 60000, 3]
        x = batch["depth_voxels"]    #[B, 128, 128, 128]

        x = x.unsqueeze(1)
        p_features = p.transpose(1, -1)

@@ -106,7 +112,7 @@ class IFGeoNet(nn.Module):

        feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True)
        net_partial = self.actvn(self.conv_in_partial(x))
        net_partial = self.partial_conv_in_bn(net_partial)
        net_partial = self.maxpool(net_partial)    # out 64

        # Feature fusion
        net = self.actvn(self.conv_0_fusion(net_partial))

@@ -119,40 +125,44 @@ class IFGeoNet(nn.Module):

        net = self.actvn(self.conv_0_1(net))
        net = self.conv0_1_bn(net)
        feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
        net = self.maxpool(net)    # out 32

        net = self.actvn(self.conv_1(net))
        net = self.actvn(self.conv_1_1(net))
        net = self.conv1_1_bn(net)
        feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
        net = self.maxpool(net)    # out 16

        net = self.actvn(self.conv_2(net))
        net = self.actvn(self.conv_2_1(net))
        net = self.conv2_1_bn(net)
        feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
        net = self.maxpool(net)    # out 8

        net = self.actvn(self.conv_3(net))
        net = self.actvn(self.conv_3_1(net))
        net = self.conv3_1_bn(net)
        feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True)
        net = self.maxpool(net)    # out 4

        net = self.actvn(self.conv_4(net))
        net = self.actvn(self.conv_4_1(net))
        net = self.conv4_1_bn(net)
        feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True)    # out 2

        # here every channel corresponse to one feature.

        features = torch.cat(
            (
                feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5,
                feature_6
            ),
            dim=1
        )    # (B, features, 1,7,sample_num)
        shape = features.shape
        features = torch.reshape(
            features, (shape[0], shape[1] * shape[3], shape[4])
        )    # (B, featues_per_sample, samples_num)
        # (B, featue_size, samples_num)
        features = torch.cat((features, p_features), dim=1)

@@ -167,4 +177,4 @@ class IFGeoNet(nn.Module):

        loss = self.l1_loss(prds, tgts)

        return loss
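`SelfAttention` above pairs a regular `Conv3d` with a second `Conv3d` whose weights are copied to zero under `torch.no_grad()` (bias is disabled). A zero-initialized convolution produces all-zero output, so whatever gating it feeds starts out neutral and only departs from that as training updates the weights. A small sketch of just that zero-init behaviour, not of the module's full forward pass (which sits outside the hunks shown here):

```python
import torch
import torch.nn as nn

attention = nn.Conv3d(4, 4, kernel_size=3, padding=1, padding_mode='replicate', bias=False)
with torch.no_grad():
    attention.weight.copy_(torch.zeros_like(attention.weight))

x = torch.randn(1, 4, 8, 8, 8)
out = attention(x)
print(out.abs().max().item())   # 0.0 -- this branch contributes nothing until its weights move
```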
lib/net/NormalNet.py
CHANGED
@@ -35,7 +35,6 @@ class NormalNet(BasePIFuNet):

    4. Classification.
    5. During training, error is calculated on all stacks.
    """
    def __init__(self, cfg):

        super(NormalNet, self).__init__()

@@ -65,9 +64,11 @@ class NormalNet(BasePIFuNet):

            item[0] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"
        ]
        self.in_nmlF_dim = sum(
            [item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"]
        )
        self.in_nmlB_dim = sum(
            [item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"]
        )

        self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3, "instance")
        self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3, "instance")

@@ -134,18 +135,20 @@ class NormalNet(BasePIFuNet):

        if 'mrf' in self.F_losses:
            mrf_F_loss = self.mrf_loss(
                F.interpolate(prd_F, scale_factor=scale_factor, mode='bicubic', align_corners=True),
                F.interpolate(tgt_F, scale_factor=scale_factor, mode='bicubic', align_corners=True)
            )
            total_loss["netF"] += self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss
            total_loss["mrf_F"] = self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss
        if 'mrf' in self.B_losses:
            mrf_B_loss = self.mrf_loss(
                F.interpolate(prd_B, scale_factor=scale_factor, mode='bicubic', align_corners=True),
                F.interpolate(tgt_B, scale_factor=scale_factor, mode='bicubic', align_corners=True)
            )
            total_loss["netB"] += self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss
            total_loss["mrf_B"] = self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss

        if 'gan' in self.ALL_losses:

            total_loss["netD"] = 0.0

            pred_fake = self.netD.forward(prd_B)

@@ -154,8 +157,8 @@ class NormalNet(BasePIFuNet):

            loss_D_real = self.gan_loss(pred_real, True)
            loss_G_fake = self.gan_loss(pred_fake, True)

            total_loss["netD"] += 0.5 * (loss_D_fake + loss_D_real
                                        ) * self.B_losses_ratio[self.B_losses.index('gan')]
            total_loss["D_fake"] = loss_D_fake * self.B_losses_ratio[self.B_losses.index('gan')]
            total_loss["D_real"] = loss_D_real * self.B_losses_ratio[self.B_losses.index('gan')]

@@ -167,8 +170,8 @@ class NormalNet(BasePIFuNet):

            for i in range(2):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += self.l1_loss(pred_fake[i][j], pred_real[i][j].detach())
            total_loss["netB"] += loss_G_GAN_Feat * self.B_losses_ratio[
                self.B_losses.index('gan_feat')]
            total_loss["G_GAN_Feat"] = loss_G_GAN_Feat * self.B_losses_ratio[
                self.B_losses.index('gan_feat')]
lib/net/geometry.py
CHANGED
@@ -19,12 +19,12 @@ import numpy as np

import numbers
from torch.nn import functional as F
from einops.einops import rearrange
"""
Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
"""


def quaternion_to_rotation_matrix(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:

@@ -42,11 +42,13 @@ def quaternion_to_rotation_matrix(quat):

    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack(
        [
            w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2,
            2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2
        ],
        dim=1
    ).view(B, 3, 3)
    return rotMat

@@ -56,7 +58,7 @@ def index(feat, uv):

    :param uv: [B, 2, N] uv coordinates in the image plane, range [0, 1]
    :return: [B, C, N] image features at the uv coordinates
    """
    uv = uv.transpose(1, 2)    # [B, N, 2]

    (B, N, _) = uv.shape
    C = feat.shape[1]

@@ -64,14 +66,14 @@ def index(feat, uv):

    if uv.shape[-1] == 3:
        # uv = uv[:,:,[2,1,0]]
        # uv = uv * torch.tensor([1.0,-1.0,1.0]).type_as(uv)[None,None,...]
        uv = uv.unsqueeze(2).unsqueeze(3)    # [B, N, 1, 1, 3]
    else:
        uv = uv.unsqueeze(2)    # [B, N, 1, 2]

    # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample
    # for old versions, simply remove the aligned_corners argument.
    samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True)    # [B, C, N, 1]
    return samples.view(B, C, N)    # [B, C, N]


def orthogonal(points, calibrations, transforms=None):

@@ -84,7 +86,7 @@ def orthogonal(points, calibrations, transforms=None):

    """
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)    # [B, 3, N]
    if transforms is not None:
        scale = transforms[:2, :2]
        shift = transforms[:2, 2:3]

@@ -102,7 +104,7 @@ def perspective(points, calibrations, transforms=None):

    """
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    homo = torch.baddbmm(trans, rot, points)    # [B, 3, N]
    xy = homo[:, :2, :] / homo[:, 2:3, :]
    if transforms is not None:
        scale = transforms[:2, :2]

@@ -187,7 +189,8 @@ def rotation_matrix_to_angle_axis(rotation_matrix):

    if rotation_matrix.shape[1:] == (3, 3):
        rot_mat = rotation_matrix.reshape(-1, 3, 3)
        hom = torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device).reshape(
            1, 3, 1
        ).expand(rot_mat.shape[0], -1, -1)
        rotation_matrix = torch.cat([rot_mat, hom], dim=-1)

    quaternion = rotation_matrix_to_quaternion(rotation_matrix)

@@ -222,8 +225,9 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:

        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))

    if not quaternion.shape[-1] == 4:
        raise ValueError(
            "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
        )
    # unpack input and compute conversion
    q1: torch.Tensor = quaternion[..., 1]
    q2: torch.Tensor = quaternion[..., 2]

@@ -276,11 +280,13 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):

        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))

    if len(rotation_matrix.shape) > 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape)
        )
    if not rotation_matrix.shape[-2:] == (3, 4):
        raise ValueError(
            "Input size must be a N x 3 x 4 tensor. Got {}".format(rotation_matrix.shape)
        )

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

@@ -347,8 +353,10 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):

    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(
        t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2    # noqa
        + t3_rep * mask_c3
    )    # noqa
    q *= 0.5
    return q

@@ -389,6 +397,7 @@ def rot6d_to_rotmat(x):

    mat = torch.stack((b1, b2, b3), dim=-1)
    return mat


def rotmat_to_rot6d(x):
    """Convert 3x3 rotation matrix to 6D rotation representation.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019

@@ -402,6 +411,7 @@ def rotmat_to_rot6d(x):

    x = x.reshape(batch_size, 6)
    return x


def rotmat_to_angle(x):
    """Convert rotation to one-D angle.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019

@@ -440,12 +450,9 @@ def projection(pred_joints, pred_camera, retain_z=False):

    return pred_keypoints_2d


def perspective_projection(
    points, rotation, translation, focal_length, camera_center, retain_z=False
):
    """
    This function computes the perspective projection of a set of points.
    Input:

@@ -501,10 +508,12 @@ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_si

    weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)

    # least squares
    Q = np.array(
        [
            F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints),
            O - np.reshape(joints_2d, -1)
        ]
    ).T
    c = (np.reshape(joints_2d, -1) - O) * Z - F * XY

    # weighted least squares

@@ -558,15 +567,12 @@ def estimate_translation(S, joints_2d, focal_length=5000., img_size=224., use_al

        S_i = S[i]
        joints_i = joints_2d[i]
        conf_i = joints_conf[i]
        trans[i] = estimate_translation_np(
            S_i, joints_i, conf_i, focal_length=focal_length[i], img_size=img_size[i]
        )
    return torch.from_numpy(trans).to(device)


def Rot_y(angle, category="torch", prepend_dim=True, device=None):
    """Rotate around y-axis by angle
    Args:

@@ -574,11 +580,13 @@ def Rot_y(angle, category="torch", prepend_dim=True, device=None):

        prepend_dim: prepend an extra dimension
    Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
    """
    m = np.array(
        [
            [np.cos(angle), 0.0, np.sin(angle)],
            [0.0, 1.0, 0.0],
            [-np.sin(angle), 0.0, np.cos(angle)],
        ]
    )
    if category == "torch":
        if prepend_dim:
            return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)

@@ -600,11 +608,13 @@ def Rot_x(angle, category="torch", prepend_dim=True, device=None):

        prepend_dim: prepend an extra dimension
    Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
    """
    m = np.array(
        [
            [1.0, 0.0, 0.0],
            [0.0, np.cos(angle), -np.sin(angle)],
            [0.0, np.sin(angle), np.cos(angle)],
        ]
    )
    if category == "torch":
        if prepend_dim:
            return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)

@@ -626,11 +636,13 @@ def Rot_z(angle, category="torch", prepend_dim=True, device=None):

        prepend_dim: prepend an extra dimension
    Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True)
    """
    m = np.array(
        [
            [np.cos(angle), -np.sin(angle), 0.0],
            [np.sin(angle), np.cos(angle), 0.0],
            [0.0, 0.0, 1.0],
        ]
    )
    if category == "torch":
        if prepend_dim:
            return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0)

@@ -672,7 +684,7 @@ def compute_twist_rotation(rotation_matrix, twist_axis):

    twist_rotation = quaternion_to_rotation_matrix(twist_quaternion)
    twist_aa = quaternion_to_angle_axis(twist_quaternion)

    twist_angle = torch.sum(twist_aa, dim=1,
                            keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True)

    return twist_rotation, twist_angle
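`orthogonal()` above applies a calibration as `rot @ points + trans` in a single batched call to `torch.baddbmm`. A small numeric check of that identity; the shapes match the function's expectations, but the calibration values are made up:

```python
import torch

B, N = 2, 5
points = torch.randn(B, 3, N)            # [B, 3, N] points, as expected by orthogonal()
calib = torch.eye(4).repeat(B, 1, 1)     # dummy calibration: identity rotation, small shift
calib[:, :3, 3] = torch.tensor([0.1, -0.2, 0.0])

rot = calib[:, :3, :3]
trans = calib[:, :3, 3:4]

pts = torch.baddbmm(trans, rot, points)  # what orthogonal() computes
ref = torch.bmm(rot, points) + trans     # the same thing, written out
print(torch.allclose(pts, ref))          # True
```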
lib/net/net_util.py
CHANGED
@@ -71,11 +71,10 @@ def init_weights(net, init_type="normal", init_gain=0.02):

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):    # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m,
                   "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":

@@ -85,17 +84,19 @@ def init_weights(net, init_type="normal", init_gain=0.02):

            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif (
            classname.find("BatchNorm2d") != -1
        ):    # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)    # apply the initialization function <init_func>

@@ -110,7 +111,7 @@ def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]):

    """
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net = torch.nn.DataParallel(net)    # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net

@@ -127,13 +128,9 @@ def imageSpaceRotation(xy, rot):

    return (disp * xy).sum(dim=1)


def cal_gradient_penalty(
    netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0
):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:

@@ -155,9 +152,11 @@ def cal_gradient_penalty(netD,

        interpolatesv = fake_data
    elif type == "mixed":
        alpha = torch.rand(real_data.shape[0], 1)
        alpha = (
            alpha.expand(real_data.shape[0],
                         real_data.nelement() //
                         real_data.shape[0]).contiguous().view(*real_data.shape)
        )
        alpha = alpha.to(device)
        interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
    else:

@@ -172,9 +171,9 @@ def cal_gradient_penalty(netD,

            retain_graph=True,
            only_inputs=True,
        )
        gradients = gradients[0].view(real_data.size(0), -1)    # flat the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant)**
                            2).mean() * lambda_gp    # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None

@@ -201,13 +200,11 @@ def get_norm_layer(norm_type="instance"):


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes, opt):
        super(ConvBlock, self).__init__()
        [k, s, d, p] = opt.conv3x3

@@ -258,5 +255,3 @@ class ConvBlock(nn.Module):

        out3 += residual

        return out3
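`init_weights()` above walks the whole network with `net.apply(...)` and dispatches on the layer's class name. A condensed, standalone version of that idea, kept to the xavier branch only and using an assumed gain, just to show how `apply` reaches every submodule:

```python
import torch.nn as nn
from torch.nn import init

def init_func(m, gain=0.02):
    classname = m.__class__.__name__
    if hasattr(m, "weight") and ("Conv" in classname or "Linear" in classname):
        init.xavier_normal_(m.weight.data, gain=gain)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
    elif "BatchNorm2d" in classname:   # weight is a vector here, so plain normal init
        init.normal_(m.weight.data, 1.0, gain)
        init.constant_(m.bias.data, 0.0)

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 2))
net.apply(init_func)   # visits every submodule, including nested ones
```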
lib/net/voxelize.py
CHANGED
@@ -13,7 +13,6 @@ class VoxelizationFunction(Function):

    Definition of differentiable voxelization function
    Currently implemented only for cuda Tensors
    """
    @staticmethod
    def forward(
        ctx,

@@ -48,12 +47,15 @@ class VoxelizationFunction(Function):

        smpl_face_code = smpl_face_code.contiguous()
        smpl_tetrahedrons = smpl_tetrahedrons.contiguous()

        occ_volume = torch.cuda.FloatTensor(
            ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res
        ).fill_(0.0)
        semantic_volume = torch.cuda.FloatTensor(
            ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res, 3
        ).fill_(0.0)
        weight_sum_volume = torch.cuda.FloatTensor(
            ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res
        ).fill_(1e-3)

        # occ_volume [B, volume_res, volume_res, volume_res]
        # semantic_volume [B, volume_res, volume_res, volume_res, 3]

@@ -80,7 +82,6 @@ class Voxelization(nn.Module):

    """
    Wrapper around the autograd function VoxelizationFunction
    """
    def __init__(
        self,
        smpl_vertex_code,

@@ -151,21 +152,25 @@ class Voxelization(nn.Module):

            self.sigma,
            self.smooth_kernel_size,
        )
        return vol.permute((0, 4, 1, 2, 3))    # (bzyxc --> bcdhw)

    def vertices_to_faces(self, vertices):
        assert vertices.ndimension() == 3
        bs, nv = vertices.shape[:2]
        face = (
            self.smpl_face_indices_batch +
            (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None]
        )
        vertices_ = vertices.reshape((bs * nv, 3))
        return vertices_[face.long()]

    def vertices_to_tetrahedrons(self, vertices):
        assert vertices.ndimension() == 3
        bs, nv = vertices.shape[:2]
        tets = (
            self.smpl_tetraderon_indices_batch +
            (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None]
        )
        vertices_ = vertices.reshape((bs * nv, 3))
        return vertices_[tets.long()]

@@ -174,8 +179,9 @@ class Voxelization(nn.Module):

        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        face_centers = (
            face_verts[:, :, 0, :] + face_verts[:, :, 1, :] + face_verts[:, :, 2, :]
        ) / 3.0
        face_centers = face_centers.reshape((bs, nf, 3))
        return face_centers
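`vertices_to_faces()` above offsets each mesh's face indices by `batch_index * num_vertices` so one flattened vertex tensor can be gathered for the whole batch in a single indexing operation. A minimal sketch of that offset trick with toy numbers (the tensors here are stand-ins, not the SMPL buffers):

```python
import torch

bs, nv = 2, 4                                             # two meshes, four vertices each
faces = torch.tensor([[0, 1, 2], [1, 2, 3]])              # same face indices for every mesh
faces_batch = faces.unsqueeze(0).repeat(bs, 1, 1)         # (bs, nf, 3)

vertices = torch.arange(bs * nv * 3, dtype=torch.float32).view(bs, nv, 3)

# offset each batch element's indices into the flattened (bs*nv, 3) vertex array
offset = (torch.arange(bs, dtype=torch.int64) * nv)[:, None, None]
flat_idx = faces_batch + offset
face_verts = vertices.reshape(bs * nv, 3)[flat_idx]       # (bs, nf, 3 verts, 3 coords)
print(face_verts.shape)
```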
lib/pixielib/models/FLAME.py
CHANGED
@@ -27,7 +27,6 @@ class FLAMETex(nn.Module):

    FLAME texture converted from BFM:
    https://github.com/TimoBolkart/BFM_to_FLAME
    """
    def __init__(self, config):
        super(FLAMETex, self).__init__()
        if config.tex_type == "BFM":

@@ -54,8 +53,7 @@ class FLAMETex(nn.Module):

        n_tex = config.n_tex
        num_components = texture_basis.shape[1]
        texture_mean = torch.from_numpy(texture_mean).float()[None, ...]
        texture_basis = torch.from_numpy(texture_basis[:, :n_tex]).float()[None, ...]
        self.register_buffer("texture_mean", texture_mean)
        self.register_buffer("texture_basis", texture_basis)

@@ -64,10 +62,8 @@ class FLAMETex(nn.Module):

        texcode: [batchsize, n_tex]
        texture: [bz, 3, 256, 256], range: 0-1
        """
        texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1)
        texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
        texture = F.interpolate(texture, [256, 256])
        texture = texture[:, [2, 1, 0], :, :]
        return texture

@@ -78,13 +74,13 @@ def texture_flame2smplx(cached_data, flame_texture, smplx_texture):

    TODO: pytorch version ==> grid sample
    """
    if smplx_texture.shape[0] != smplx_texture.shape[1]:
        print("SMPL-X texture not squared (%d != %d)" % (smplx_texture[0], smplx_texture[1]))
        return
    if smplx_texture.shape[0] != cached_data["target_resolution"]:
        print(
            "SMPL-X texture size does not match cached image resolution (%d != %d)" %
            (smplx_texture.shape[0], cached_data["target_resolution"])
        )
        return
    x_coords = cached_data["x_coords"]
    y_coords = cached_data["y_coords"]

@@ -98,11 +94,13 @@ def texture_flame2smplx(cached_data, flame_texture, smplx_texture):

        flame_texture.shape[0],
    ).astype(int)
    source_tex_coords[:, 1] = np.clip(
        flame_texture.shape[1] * (source_uv_points[:, 0]), 0.0, flame_texture.shape[1]
    ).astype(int)

    smplx_texture[y_coords[target_pixel_ids].astype(int),
                  x_coords[target_pixel_ids].astype(int), :, ] = \
        flame_texture[source_tex_coords[:, 0], source_tex_coords[:, 1]]

    return smplx_texture
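`FLAMETex.forward()` above reconstructs a texture as a linear model, `mean + basis · coefficients`, then reshapes the flat vector back into an image. A toy version of that step with made-up dimensions (the real mean and basis are loaded from the BFM/FLAME texture files):

```python
import torch

n_tex, n_pix = 10, 512 * 512 * 3                 # hypothetical sizes
texture_mean = torch.zeros(1, n_pix)             # stand-ins for the loaded statistics
texture_basis = torch.randn(1, n_pix, n_tex)

texcode = torch.randn(4, n_tex)                  # per-sample PCA coefficients
texture = texture_mean + (texture_basis * texcode[:, None, :]).sum(-1)
texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2)
print(texture.shape)                             # torch.Size([4, 3, 512, 512])
```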
lib/pixielib/models/SMPLX.py
CHANGED
@@ -209,452 +209,468 @@ extra_names = [
|
|
209 |
SMPLX_names += extra_names
|
210 |
|
211 |
part_indices = {}
|
212 |
-
part_indices["body"] = np.array(
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
658 |
# kinematic tree
|
659 |
head_kin_chain = [15, 12, 9, 6, 3, 0]
|
660 |
|
@@ -691,13 +707,12 @@ class SMPLX(nn.Module):
|
|
691 |
Given smplx parameters, this class generates a differentiable SMPLX function
|
692 |
which outputs a mesh and 3D joints
|
693 |
"""
|
694 |
-
|
695 |
def __init__(self, config):
|
696 |
super(SMPLX, self).__init__()
|
697 |
# print("creating the SMPLX Decoder")
|
698 |
ss = np.load(config.smplx_model_path, allow_pickle=True)
|
699 |
smplx_model = Struct(**ss)
|
700 |
-
|
701 |
self.dtype = torch.float32
|
702 |
self.register_buffer(
|
703 |
"faces_tensor",
|
@@ -705,8 +720,8 @@ class SMPLX(nn.Module):
|
|
705 |
)
|
706 |
# The vertices of the template model
|
707 |
self.register_buffer(
|
708 |
-
"v_template",
|
709 |
-
|
710 |
# The shape components and expression
|
711 |
# expression space is the same as FLAME
|
712 |
shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
|
@@ -721,21 +736,18 @@ class SMPLX(nn.Module):
|
|
721 |
# The pose components
|
722 |
num_pose_basis = smplx_model.posedirs.shape[-1]
|
723 |
posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
|
724 |
-
self.register_buffer("posedirs",
|
725 |
-
to_tensor(to_np(posedirs), dtype=self.dtype))
|
726 |
self.register_buffer(
|
727 |
-
"J_regressor",
|
728 |
-
|
729 |
parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
|
730 |
parents[0] = -1
|
731 |
self.register_buffer("parents", parents)
|
732 |
-
self.register_buffer(
|
733 |
-
"lbs_weights",
|
734 |
-
to_tensor(to_np(smplx_model.weights), dtype=self.dtype))
|
735 |
# for face keypoints
|
736 |
self.register_buffer(
|
737 |
-
"lmk_faces_idx",
|
738 |
-
|
739 |
self.register_buffer(
|
740 |
"lmk_bary_coords",
|
741 |
torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
|
@@ -746,24 +758,20 @@ class SMPLX(nn.Module):
|
|
746 |
)
|
747 |
self.register_buffer(
|
748 |
"dynamic_lmk_bary_coords",
|
749 |
-
torch.tensor(smplx_model.dynamic_lmk_bary_coords,
|
750 |
-
dtype=self.dtype),
|
751 |
)
|
752 |
# pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
|
753 |
-
self.register_buffer("head_kin_chain",
|
754 |
-
torch.tensor(head_kin_chain, dtype=torch.long))
|
755 |
|
756 |
# -- initialize parameters
|
757 |
# shape and expression
|
758 |
self.register_buffer(
|
759 |
"shape_params",
|
760 |
-
nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype),
|
761 |
-
requires_grad=False),
|
762 |
)
|
763 |
self.register_buffer(
|
764 |
"expression_params",
|
765 |
-
nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype),
|
766 |
-
requires_grad=False),
|
767 |
)
|
768 |
# pose: represented as rotation matrx [number of joints, 3, 3]
|
769 |
self.register_buffer(
|
@@ -824,8 +832,7 @@ class SMPLX(nn.Module):
|
|
824 |
)
|
825 |
|
826 |
if config.extra_joint_path:
|
827 |
-
self.extra_joint_selector = JointsFromVerticesSelector(
|
828 |
-
fname=config.extra_joint_path)
|
829 |
self.use_joint_regressor = True
|
830 |
self.keypoint_names = SMPLX_names
|
831 |
if self.use_joint_regressor:
|
@@ -843,7 +850,8 @@ class SMPLX(nn.Module):
|
|
843 |
self.register_buffer("target_idxs", torch.from_numpy(target))
|
844 |
self.register_buffer(
|
845 |
"extra_joint_regressor",
|
846 |
-
torch.from_numpy(j14_regressor).to(torch.float32)
|
|
|
847 |
self.part_indices = part_indices
|
848 |
|
849 |
def forward(
|
@@ -880,23 +888,17 @@ class SMPLX(nn.Module):
|
|
880 |
if expression_params is None:
|
881 |
expression_params = self.expression_params.expand(batch_size, -1)
|
882 |
if global_pose is None:
|
883 |
-
global_pose = self.global_pose.unsqueeze(0).expand(
|
884 |
-
batch_size, -1, -1, -1)
|
885 |
if body_pose is None:
|
886 |
-
body_pose = self.body_pose.unsqueeze(0).expand(
|
887 |
-
batch_size, -1, -1, -1)
|
888 |
if jaw_pose is None:
|
889 |
-
jaw_pose = self.jaw_pose.unsqueeze(0).expand(
|
890 |
-
batch_size, -1, -1, -1)
|
891 |
if eye_pose is None:
|
892 |
-
eye_pose = self.eye_pose.unsqueeze(0).expand(
|
893 |
-
batch_size, -1, -1, -1)
|
894 |
if left_hand_pose is None:
|
895 |
-
left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(
|
896 |
-
batch_size, -1, -1, -1)
|
897 |
if right_hand_pose is None:
|
898 |
-
right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(
|
899 |
-
batch_size, -1, -1, -1)
|
900 |
|
901 |
shape_components = torch.cat([shape_params, expression_params], dim=1)
|
902 |
full_pose = torch.cat(
|
@@ -910,8 +912,7 @@ class SMPLX(nn.Module):
|
|
910 |
],
|
911 |
dim=1,
|
912 |
)
|
913 |
-
template_vertices = self.v_template.unsqueeze(0).expand(
|
914 |
-
batch_size, -1, -1)
|
915 |
# smplx
|
916 |
vertices, joints = lbs(
|
917 |
shape_components,
|
@@ -926,10 +927,8 @@ class SMPLX(nn.Module):
|
|
926 |
pose2rot=False,
|
927 |
)
|
928 |
# face dynamic landmarks
|
929 |
-
lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
|
930 |
-
|
931 |
-
lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(
|
932 |
-
batch_size, -1, -1)
|
933 |
dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
|
934 |
vertices,
|
935 |
full_pose,
|
@@ -939,14 +938,12 @@ class SMPLX(nn.Module):
|
|
939 |
)
|
940 |
lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
|
941 |
lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
|
942 |
-
landmarks = vertices2landmarks(vertices, self.faces_tensor,
|
943 |
-
lmk_faces_idx, lmk_bary_coords)
|
944 |
|
945 |
final_joint_set = [joints, landmarks]
|
946 |
if hasattr(self, "extra_joint_selector"):
|
947 |
# Add any extra joints that might be needed
|
948 |
-
extra_joints = self.extra_joint_selector(vertices,
|
949 |
-
self.faces_tensor)
|
950 |
final_joint_set.append(extra_joints)
|
951 |
# Create the final joint set
|
952 |
joints = torch.cat(final_joint_set, dim=1)
|
@@ -978,16 +975,15 @@ class SMPLX(nn.Module):
|
|
978 |
# -> Left elbow -> Left wrist
|
979 |
kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
|
980 |
else:
|
981 |
-
raise NotImplementedError(
|
982 |
-
f"pose_abs2rel does not support: {abs_joint}")
|
983 |
|
984 |
batch_size = global_pose.shape[0]
|
985 |
dtype = global_pose.dtype
|
986 |
device = global_pose.device
|
987 |
full_pose = torch.cat([global_pose, body_pose], dim=1)
|
988 |
-
rel_rot_mat = (
|
989 |
-
|
990 |
-
|
991 |
for idx in kin_chain[1:]:
|
992 |
rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)
|
993 |
|
@@ -1027,11 +1023,8 @@ class SMPLX(nn.Module):
|
|
1027 |
# -> Left elbow -> Left wrist
|
1028 |
kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
|
1029 |
else:
|
1030 |
-
raise NotImplementedError(
|
1031 |
-
|
1032 |
-
rel_rot_mat = torch.eye(3,
|
1033 |
-
device=full_pose.device,
|
1034 |
-
dtype=full_pose.dtype).unsqueeze_(dim=0)
|
1035 |
for idx in kin_chain:
|
1036 |
rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
|
1037 |
abs_pose = rel_rot_mat[:, None, :, :]
|
|
|
SMPLX_names += extra_names

part_indices = {}
part_indices["body"] = np.array([
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
    123, 124, 125, 126, 127, 132, 134, 135, 136, 137, 138, 143
])
part_indices["torso"] = np.array([
    0, 1, 2, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 55, 56, 57, 58, 59, 76, 77, 78,
    79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101,
    102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
    121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
    140, 141, 142, 143, 144
])
part_indices["head"] = np.array([
    12, 15, 22, 23, 24, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
    73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
    96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114,
    115, 116, 117, 118, 119, 120, 121, 122, 123, 125, 126, 134, 136, 137
])
part_indices["face"] = np.array([
    55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
    78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100,
    101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
    119, 120, 121, 122
])
part_indices["upper"] = np.array([
    12, 13, 14, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
    75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
    98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
    117, 118, 119, 120, 121, 122
])
part_indices["hand"] = np.array([
    20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
    46, 47, 48, 49, 50, 51, 52, 53, 54, 128, 129, 130, 131, 133, 139, 140, 141, 142, 144
])
part_indices["left_hand"] = np.array([
    20, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 128, 129, 130, 131, 133
])
part_indices["right_hand"] = np.array([
    21, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 139, 140, 141, 142, 144
])
# kinematic tree
head_kin_chain = [15, 12, 9, 6, 3, 0]

    Given smplx parameters, this class generates a differentiable SMPLX function
    which outputs a mesh and 3D joints
    """
    def __init__(self, config):
        super(SMPLX, self).__init__()
        # print("creating the SMPLX Decoder")
        ss = np.load(config.smplx_model_path, allow_pickle=True)
        smplx_model = Struct(**ss)

        self.dtype = torch.float32
        self.register_buffer(
            "faces_tensor",
        )
        # The vertices of the template model
        self.register_buffer(
            "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)
        )
        # The shape components and expression
        # expression space is the same as FLAME
        shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)

        # The pose components
        num_pose_basis = smplx_model.posedirs.shape[-1]
        posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
        self.register_buffer(
            "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)
        )
        parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)
        self.register_buffer("lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype))

        # for face keypoints
        self.register_buffer(
            "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)
        )
        self.register_buffer(
            "lmk_bary_coords",
            torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
        )
        self.register_buffer(
            "dynamic_lmk_bary_coords",
            torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype),
        )
        # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
        self.register_buffer("head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long))

        # -- initialize parameters
        # shape and expression
        self.register_buffer(
            "shape_params",
            nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False),
        )
        self.register_buffer(
            "expression_params",
            nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False),
        )
        # pose: represented as rotation matrx [number of joints, 3, 3]
        self.register_buffer(

        if config.extra_joint_path:
            self.extra_joint_selector = JointsFromVerticesSelector(fname=config.extra_joint_path)
        self.use_joint_regressor = True
        self.keypoint_names = SMPLX_names
        if self.use_joint_regressor:
            self.register_buffer("target_idxs", torch.from_numpy(target))
            self.register_buffer(
                "extra_joint_regressor",
                torch.from_numpy(j14_regressor).to(torch.float32)
            )
        self.part_indices = part_indices

    def forward(

        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)
        if global_pose is None:
            global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if body_pose is None:
            body_pose = self.body_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if jaw_pose is None:
            jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if eye_pose is None:
            eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if left_hand_pose is None:
            left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if right_hand_pose is None:
            right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)

        shape_components = torch.cat([shape_params, expression_params], dim=1)
        full_pose = torch.cat(
            ],
            dim=1,
        )
        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        # smplx
        vertices, joints = lbs(
            shape_components,
            pose2rot=False,
        )
        # face dynamic landmarks
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        final_joint_set = [joints, landmarks]
        if hasattr(self, "extra_joint_selector"):
            # Add any extra joints that might be needed
            extra_joints = self.extra_joint_selector(vertices, self.faces_tensor)
            final_joint_set.append(extra_joints)
        # Create the final joint set
        joints = torch.cat(final_joint_set, dim=1)

            # -> Left elbow -> Left wrist
            kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
        else:
            raise NotImplementedError(f"pose_abs2rel does not support: {abs_joint}")

        batch_size = global_pose.shape[0]
        dtype = global_pose.dtype
        device = global_pose.device
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        rel_rot_mat = (
            torch.eye(3, device=device, dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)
        )
        for idx in kin_chain[1:]:
            rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)

            # -> Left elbow -> Left wrist
            kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
        else:
            raise NotImplementedError(f"pose_rel2abs does not support: {abs_joint}")
        rel_rot_mat = torch.eye(3, device=full_pose.device, dtype=full_pose.dtype).unsqueeze_(dim=0)
        for idx in kin_chain:
            rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
        abs_pose = rel_rot_mat[:, None, :, :]
lib/pixielib/models/encoders.py
CHANGED
class ResnetEncoder(nn.Module):
    def __init__(self, append_layers=None):
        super(ResnetEncoder, self).__init__()
        from . import resnet

        # feature_size = 2048
        self.feature_dim = 2048
        self.encoder = resnet.load_ResNet50Model()    # out: 2048
        # regressor
        self.append_layers = append_layers


class MLP(nn.Module):
    def __init__(self, channels=[2048, 1024, 1], last_op=None):
        super(MLP, self).__init__()
        layers = []


class HRNEncoder(nn.Module):
    def __init__(self, append_layers=None):
        super(HRNEncoder, self).__init__()
        from . import hrnet

        self.feature_dim = 2048
        self.encoder = hrnet.load_HRNet(pretrained=True)    # out: 2048
        # regressor
        self.append_layers = append_layers
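The MLP regressor above is configured purely by a channels list such as [2048, 1024, 1]. A rough sketch of how such a list can be turned into a stack of linear layers; the actual PIXIE layer choices and last_op handling may differ:

    import torch
    import torch.nn as nn

    def build_mlp(channels, last_op=None):
        # Linear layers sized by consecutive entries, ReLU between hidden layers only.
        layers = []
        for i in range(len(channels) - 1):
            layers.append(nn.Linear(channels[i], channels[i + 1]))
            if i < len(channels) - 2:
                layers.append(nn.ReLU(inplace=True))
        if last_op is not None:
            layers.append(last_op)
        return nn.Sequential(*layers)

    mlp = build_mlp([2048, 1024, 1])
    print(mlp(torch.randn(4, 2048)).shape)    # torch.Size([4, 1])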
lib/pixielib/models/hrnet.py
CHANGED
def load_HRNet(pretrained=False):
    hr_net_cfg_dict = {
        "use_old_impl": False,
        "pretrained_layers": ["*"],
        "stage1": {
            "num_modules": 1,
            "num_branches": 1,
            "num_blocks": [4],
            "num_channels": [64],
            "block": "BOTTLENECK",
            "fuse_method": "SUM",
        },
        "stage2": {
            "num_modules": 1,
            "num_branches": 2,
            "num_blocks": [4, 4],
            "num_channels": [48, 96],
            "block": "BASIC",
            "fuse_method": "SUM",
        },
        "stage3": {
            "num_modules": 4,
            "num_branches": 3,
            "num_blocks": [4, 4, 4],
            "num_channels": [48, 96, 192],
            "block": "BASIC",
            "fuse_method": "SUM",
        },
        "stage4": {
            "num_modules": 3,
            "num_branches": 4,
            "num_blocks": [4, 4, 4, 4],
            "num_channels": [48, 96, 192, 384],
            "block": "BASIC",
            "fuse_method": "SUM",
        },
    }
    hr_net_cfg = hr_net_cfg_dict
    model = HighResolutionNet(hr_net_cfg)


class HighResolutionModule(nn.Module):
    def __init__(
        self,
        num_branches,
        multi_scale_output=True,
    ):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
                num_branches, len(num_channels)
            )
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
                num_branches, len(num_inchannels)
            )
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        downsample = None
        if (
            stride != 1 or
            self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion
        ):
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
                    num_channels[branch_index],
                    stride,
                    downsample,
                )
            )
            self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

                        bias=False,
                    ),
                    nn.BatchNorm2d(num_inchannels[i]),
                    nn.Upsample(scale_factor=2**(j - i), mode="nearest"),
                )
            )
        elif j == i:
            fuse_layer.append(None)
        else:
                        bias=False,
                    ),
                    nn.BatchNorm2d(num_outchannels_conv3x3),
                )
            )
        else:
            num_outchannels_conv3x3 = num_inchannels[j]
            conv3x3s.append(
                    ),
                    nn.BatchNorm2d(num_outchannels_conv3x3),
                    nn.ReLU(True),
                )
            )
        fuse_layer.append(nn.Sequential(*conv3x3s))
        fuse_layers.append(nn.ModuleList(fuse_layer))


class HighResolutionNet(nn.Module):
    def __init__(self, cfg, **kwargs):
        self.inplanes = 64
        super(HighResolutionNet, self).__init__()
        self.use_old_impl = use_old_impl

        # stem net
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)

        self.stage2_cfg = cfg.get("stage2", {})
        num_channels = self.stage2_cfg.get("num_channels", (32, 64))
        block = blocks_dict[self.stage2_cfg.get("block")]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        stage2_num_channels = num_channels
        self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg.get("stage3")
        num_channels = self.stage3_cfg["num_channels"]
        block = blocks_dict[self.stage3_cfg["block"]]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        stage3_num_channels = num_channels
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg.get("stage4")
        num_channels = self.stage4_cfg["num_channels"]
        block = blocks_dict[self.stage4_cfg["block"]]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        stage_4_out_channels = num_channels

        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=not self.use_old_impl
        )
        stage4_num_channels = num_channels

        self.output_channels_dim = pre_stage_channels

        self.avg_pooling = nn.AdaptiveAvgPool2d(1)

        if use_old_impl:
            in_dims = (
                2**2 * stage2_num_channels[-1] + 2**1 * stage3_num_channels[-1] +
                stage_4_out_channels[-1]
            )
        else:
            # TODO: Replace with parameters
            in_dims = 4 * 384
            self.subsample_4 = self._make_subsample_layer(
                in_channels=stage4_num_channels[0], num_layers=3
            )

        self.subsample_3 = self._make_subsample_layer(
            in_channels=stage2_num_channels[-1], num_layers=2
        )
        self.subsample_2 = self._make_subsample_layer(
            in_channels=stage3_num_channels[-1], num_layers=1
        )
        self.conv_layers = self._make_conv_layer(in_channels=in_dims, num_layers=5)

    def get_output_dim(self):
        base_output = {f"layer{idx + 1}": val for idx, val in enumerate(self.output_channels_dim)}
        output = base_output.copy()
        for key in base_output:
            output[f"{key}_avg_pooling"] = output[key]
        output["concat"] = 2048
        return output

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

                        ),
                        nn.BatchNorm2d(num_channels_cur_layer[i]),
                        nn.ReLU(inplace=True),
                    )
                )
            else:
                transition_layers.append(None)
        else:
            conv3x3s = []
            for j in range(i + 1 - num_branches_pre):
                inchannels = num_channels_pre_layer[-1]
                outchannels = (
                    num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
                )
                conv3x3s.append(
                    nn.Sequential(
                        nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels),
                        nn.ReLU(inplace=True),
                    )
                )
            transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

        return nn.Sequential(*layers)

    def _make_conv_layer(self, in_channels=2048, num_layers=3, num_filters=2048, stride=1):
        layers = []
        for i in range(num_layers):
            downsample = nn.Conv2d(in_channels, num_filters, stride=1, kernel_size=1, bias=False)
            layers.append(Bottleneck(in_channels, num_filters // 4, downsample=downsample))
            in_channels = num_filters

        return nn.Sequential(*layers)

                kernel_size=3,
                stride=stride,
                padding=1,
            )
        )
        in_channels = 2 * in_channels
        layers.append(nn.BatchNorm2d(in_channels, momentum=BN_MOMENTUM))
        layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True, log=False):
        num_modules = layer_config["num_modules"]
        num_branches = layer_config["num_branches"]
        num_blocks = layer_config["num_blocks"]

                    num_channels,
                    fuse_method,
                    reset_multi_scale_output,
                )
            )
            modules[-1].log = log
            num_inchannels = modules[-1].get_num_inchannels()

    def load_weights(self, pretrained=""):
        pretrained = osp.expandvars(pretrained)
        if osp.isfile(pretrained):
            pretrained_state_dict = torch.load(pretrained, map_location=torch.device("cpu"))

            need_init_state_dict = {}
            for name, m in pretrained_state_dict.items():
                if (
                    name.split(".")[0] in self.pretrained_layers or self.pretrained_layers[0] == "*"
                ):
                    need_init_state_dict[name] = m
            missing, unexpected = self.load_state_dict(need_init_state_dict, strict=False)
        elif pretrained:
            raise ValueError("{} is not exist!".format(pretrained))
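The inlined hr_net_cfg_dict above fully determines the HRNet topology: each stage entry gives the number of parallel branches, the blocks per branch, and the channel widths. An illustrative snippet for reading one stage entry; variable names here are assumptions, not code from this commit:

    # Reads a stage description like the one in hr_net_cfg_dict and prints what it builds.
    hr_cfg = {
        "stage2": {
            "num_modules": 1, "num_branches": 2, "num_blocks": [4, 4],
            "num_channels": [48, 96], "block": "BASIC", "fuse_method": "SUM",
        },
    }

    stage = hr_cfg["stage2"]
    for branch, (blocks, chans) in enumerate(zip(stage["num_blocks"], stage["num_channels"])):
        print(f"branch {branch}: {blocks} x {stage['block']} blocks @ {chans} channels")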
lib/pixielib/models/lbs.py
CHANGED
    # Calculates rotation matrix to euler angles
    # Careful for extreme cases of eular angles like [0.0, pi, 0.0]

    sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
    return torch.atan2(-rot_mats[:, 2, 0], sy)

    # aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
    rot_mats = torch.index_select(pose, 1, head_kin_chain)

    rel_rot_mat = torch.eye(3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0)
    for idx in range(len(head_kin_chain)):
        # rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
        rel_rot_mat = torch.matmul(rot_mats[:, idx], rel_rot_mat)

    y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
                                          max=39)).to(dtype=torch.long)
    # print(y_rot_angle[0])
    neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
    mask = y_rot_angle.lt(-39).to(dtype=torch.long)

    y_rot_angle = neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle
    # print(y_rot_angle[0])

    dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, y_rot_angle)
    dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle)

    return dyn_lmk_faces_idx, dyn_lmk_b_coords

    batch_size, num_verts = vertices.shape[:2]
    device = vertices.device

    lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(batch_size, -1, 3)

    lmk_faces += (
        torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
    )

    lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)

    # N x J x 3 x 3
    ident = torch.eye(3, dtype=dtype, device=device)
    if pose2rot:
        rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])

        pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
        # (N x P) x (P, V * 3) -> N x V x 3
        pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
    else:
        pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
        rot_mats = pose.view(batch_size, -1, 3, 3)

    W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
    # (N x V x (J + 1)) x (N x (J + 1) x 16)
    num_joints = J_regressor.shape[0]
    T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)

    homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device)
    v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
    v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))

    K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)

    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)

    - T: Bx4x4 Transformation matrix
    """
    # No padding left or right, only add an extra row
    return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)


def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):

    rel_joints[:, 1:] -= joints[:, parents[1:]]

    transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
                                   rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)

    transform_chain = [transforms_mat[:, 0]]
    for i in range(1, parents.shape[0]):
        # Subtract the joint location at the rest pose
        # No need for rotation, since it's identity when at rest
        curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i])
        transform_chain.append(curr_res)

    transforms = torch.stack(transform_chain, dim=1)

    joints_homogen = F.pad(joints, [0, 0, 0, 1])

    rel_transforms = transforms - F.pad(
        torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
    )

    return posed_joints, rel_transforms


class JointsFromVerticesSelector(nn.Module):
    def __init__(self, fname):
        """Selects extra joints from vertices"""
        super(JointsFromVerticesSelector, self).__init__()

        err_msg = ("Either pass a filename or triangle face ids, names and"
                   " barycentrics")
        assert fname is not None or (
            face_ids is not None and bcs is not None and names is not None
        ), err_msg
        if fname is not None:
            fname = os.path.expanduser(os.path.expandvars(fname))
            with open(fname, "r") as f:

        assert len(bcs) == len(
            face_ids
        ), "The number of barycentric coordinates must be equal to the faces"
        assert len(names) == len(face_ids), "The number of names must be equal to the number of "

        self.names = names
        self.register_buffer("bcs", torch.tensor(bcs, dtype=torch.float32))
        self.register_buffer("face_ids", torch.tensor(face_ids, dtype=torch.long))

    def extra_joint_names(self):
        """Returns the names of the extra joints"""

            return []
        vertex_ids = faces[self.face_ids].reshape(-1)
        # Should be BxNx3x3
        triangles = torch.index_select(vertices, 1, vertex_ids).reshape(-1, len(self.bcs), 3, 3)
        return (triangles * self.bcs[None, :, :, None]).sum(dim=2)


class Struct(object):
    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)
lib/pixielib/models/moderators.py
CHANGED
class TempSoftmaxFusion(nn.Module):
    def __init__(self, channels=[2048 * 2, 1024, 1], detach_inputs=False, detach_feature=False):
        super(TempSoftmaxFusion, self).__init__()
        self.detach_inputs = detach_inputs
        self.detach_feature = detach_feature


class GumbelSoftmaxFusion(nn.Module):
    def __init__(self, channels=[2048 * 2, 1024, 1], detach_inputs=False, detach_feature=False):
        super(GumbelSoftmaxFusion, self).__init__()
        self.detach_inputs = detach_inputs
        self.detach_feature = detach_feature
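TempSoftmaxFusion weights two part-specific features with softmax weights predicted from their concatenation, scaled by a temperature. A toy sketch of that idea; the real module predicts the weights with an MLP over the given channels list and supports detaching inputs or features:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ToyFusion(nn.Module):
        def __init__(self, dim=2048, temperature=1.0):
            super().__init__()
            self.score = nn.Linear(dim * 2, 2)     # one logit per input branch
            self.temperature = temperature

        def forward(self, feat_a, feat_b):
            logits = self.score(torch.cat([feat_a, feat_b], dim=1)) / self.temperature
            w = F.softmax(logits, dim=1)           # (B, 2) fusion weights
            return w[:, 0:1] * feat_a + w[:, 1:2] * feat_b

    fusion = ToyFusion(dim=8)
    print(fusion(torch.randn(4, 8), torch.randn(4, 8)).shape)    # torch.Size([4, 8])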
lib/pixielib/models/resnet.py
CHANGED
@@ -22,16 +22,10 @@ from torchvision import models
|
|
22 |
|
23 |
|
24 |
class ResNet(nn.Module):
|
25 |
-
|
26 |
def __init__(self, block, layers, num_classes=1000):
|
27 |
self.inplanes = 64
|
28 |
super(ResNet, self).__init__()
|
29 |
-
self.conv1 = nn.Conv2d(3,
|
30 |
-
64,
|
31 |
-
kernel_size=7,
|
32 |
-
stride=2,
|
33 |
-
padding=3,
|
34 |
-
bias=False)
|
35 |
self.bn1 = nn.BatchNorm2d(64)
|
36 |
self.relu = nn.ReLU(inplace=True)
|
37 |
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
@@ -98,12 +92,7 @@ class Bottleneck(nn.Module):
|
|
98 |
super(Bottleneck, self).__init__()
|
99 |
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
100 |
self.bn1 = nn.BatchNorm2d(planes)
|
101 |
-
self.conv2 = nn.Conv2d(planes,
|
102 |
-
planes,
|
103 |
-
kernel_size=3,
|
104 |
-
stride=stride,
|
105 |
-
padding=1,
|
106 |
-
bias=False)
|
107 |
self.bn2 = nn.BatchNorm2d(planes)
|
108 |
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
109 |
self.bn3 = nn.BatchNorm2d(planes * 4)
|
@@ -136,12 +125,7 @@ class Bottleneck(nn.Module):
|
|
136 |
|
137 |
def conv3x3(in_planes, out_planes, stride=1):
|
138 |
"""3x3 convolution with padding"""
|
139 |
-
return nn.Conv2d(in_planes,
|
140 |
-
out_planes,
|
141 |
-
kernel_size=3,
|
142 |
-
stride=stride,
|
143 |
-
padding=1,
|
144 |
-
bias=False)
|
145 |
|
146 |
|
147 |
class BasicBlock(nn.Module):
|
@@ -196,8 +180,7 @@ def load_ResNet50Model():
|
|
196 |
model = ResNet(Bottleneck, [3, 4, 6, 3])
|
197 |
copy_parameter_from_resnet(
|
198 |
model,
|
199 |
-
torchvision.models.resnet50(
|
200 |
-
weights=models.ResNet50_Weights.DEFAULT).state_dict(),
|
201 |
)
|
202 |
return model
|
203 |
|
@@ -206,8 +189,7 @@ def load_ResNet101Model():
|
|
206 |
model = ResNet(Bottleneck, [3, 4, 23, 3])
|
207 |
copy_parameter_from_resnet(
|
208 |
model,
|
209 |
-
torchvision.models.resnet101(
|
210 |
-
weights=models.ResNet101_Weights.DEFAULT).state_dict(),
|
211 |
)
|
212 |
return model
|
213 |
|
@@ -216,8 +198,7 @@ def load_ResNet152Model():
|
|
216 |
model = ResNet(Bottleneck, [3, 8, 36, 3])
|
217 |
copy_parameter_from_resnet(
|
218 |
model,
|
219 |
-
torchvision.models.resnet152(
|
220 |
-
weights=models.ResNet152_Weights.DEFAULT).state_dict(),
|
221 |
)
|
222 |
return model
|
223 |
|
@@ -229,7 +210,6 @@ def load_ResNet152Model():
|
|
229 |
|
230 |
class DoubleConv(nn.Module):
|
231 |
"""(convolution => [BN] => ReLU) * 2"""
|
232 |
-
|
233 |
def __init__(self, in_channels, out_channels):
|
234 |
super().__init__()
|
235 |
self.double_conv = nn.Sequential(
|
@@ -247,11 +227,9 @@ class DoubleConv(nn.Module):
|
|
247 |
|
248 |
class Down(nn.Module):
|
249 |
"""Downscaling with maxpool then double conv"""
|
250 |
-
|
251 |
def __init__(self, in_channels, out_channels):
|
252 |
super().__init__()
|
253 |
-
self.maxpool_conv = nn.Sequential(
|
254 |
-
nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
|
255 |
|
256 |
def forward(self, x):
|
257 |
return self.maxpool_conv(x)
|
@@ -259,20 +237,16 @@ class Down(nn.Module):
|
|
259 |
|
260 |
class Up(nn.Module):
|
261 |
"""Upscaling then double conv"""
|
262 |
-
|
263 |
def __init__(self, in_channels, out_channels, bilinear=True):
|
264 |
super().__init__()
|
265 |
|
266 |
# if bilinear, use the normal convolutions to reduce the number of channels
|
267 |
if bilinear:
|
268 |
-
self.up = nn.Upsample(scale_factor=2,
|
269 |
-
mode="bilinear",
|
270 |
-
align_corners=True)
|
271 |
else:
|
272 |
-
self.up = nn.ConvTranspose2d(
|
273 |
-
|
274 |
-
|
275 |
-
stride=2)
|
276 |
|
277 |
self.conv = DoubleConv(in_channels, out_channels)
|
278 |
|
@@ -282,9 +256,7 @@ class Up(nn.Module):
|
|
282 |
diffY = x2.size()[2] - x1.size()[2]
|
283 |
diffX = x2.size()[3] - x1.size()[3]
|
284 |
|
285 |
-
x1 = F.pad(
|
286 |
-
x1,
|
287 |
-
[diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
|
288 |
# if you have padding issues, see
|
289 |
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
|
290 |
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
|
@@ -293,7 +265,6 @@ class Up(nn.Module):
|
|
293 |
|
294 |
|
295 |
class OutConv(nn.Module):
|
296 |
-
|
297 |
def __init__(self, in_channels, out_channels):
|
298 |
super(OutConv, self).__init__()
|
299 |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
@@ -303,7 +274,6 @@ class OutConv(nn.Module):
|
|
303 |
|
304 |
|
305 |
class UNet(nn.Module):
|
306 |
-
|
307 |
def __init__(self, n_channels, n_classes, bilinear=True):
|
308 |
super(UNet, self).__init__()
|
309 |
self.n_channels = n_channels
|
|
|
22 |
|
23 |
|
24 |
class ResNet(nn.Module):
|
|
|
25 |
def __init__(self, block, layers, num_classes=1000):
|
26 |
self.inplanes = 64
|
27 |
super(ResNet, self).__init__()
|
28 |
+
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
+       self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
+   return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):

    model = ResNet(Bottleneck, [3, 4, 6, 3])
    copy_parameter_from_resnet(
        model,
+       torchvision.models.resnet50(weights=models.ResNet50_Weights.DEFAULT).state_dict(),
    )
    return model

    model = ResNet(Bottleneck, [3, 4, 23, 3])
    copy_parameter_from_resnet(
        model,
+       torchvision.models.resnet101(weights=models.ResNet101_Weights.DEFAULT).state_dict(),
    )
    return model

    model = ResNet(Bottleneck, [3, 8, 36, 3])
    copy_parameter_from_resnet(
        model,
+       torchvision.models.resnet152(weights=models.ResNet152_Weights.DEFAULT).state_dict(),
    )
    return model


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
+       self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
+           self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        else:
+           self.up = nn.ConvTranspose2d(
+               in_channels // 2, in_channels // 2, kernel_size=2, stride=2
+           )

        self.conv = DoubleConv(in_channels, out_channels)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

+       x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
lib/pixielib/pixie.py
CHANGED
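Most of the hunks below only reflow calls such as `self.decompose_code(...)`, which splits a regressor's flat output vector into named parameter tensors according to `param_list_dict` (a `{name: n_dims}` mapping built in `__init__`). A minimal sketch of that split-by-name pattern, with illustrative dimensions (the real values come from `cfg.model.n_*`):

```python
import torch

def split_params(code, num_dict):
    # Split a (batch, sum(dims)) tensor into {name: (batch, dim)} chunks,
    # in the order the names appear in num_dict.
    out, start = {}, 0
    for name, dim in num_dict.items():
        out[name] = code[:, start:start + dim]
        start += dim
    return out

param_list = {"shape": 200, "exp": 50, "jaw_pose": 3}      # illustrative only
flat = torch.randn(2, sum(param_list.values()))
params = split_params(flat, param_list)                     # params["shape"].shape == (2, 200)
```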
@@ -33,7 +33,6 @@ from .utils.config import cfg


class PIXIE(object):
-
    def __init__(self, config=None, device="cuda:0"):
        if config is None:
            self.cfg = cfg
@@ -45,10 +44,7 @@ class PIXIE(object):
        self.param_list_dict = {}
        for lst in self.cfg.params.keys():
            param_list = cfg.params.get(lst)
-           self.param_list_dict[lst] = {
-               i: cfg.model.get("n_" + i)
-               for i in param_list
-           }
+           self.param_list_dict[lst] = {i: cfg.model.get("n_" + i) for i in param_list}

        # Build the models
        self._create_model()
@@ -97,24 +93,19 @@ class PIXIE(object):
        self.Regressor = {}
        for key in self.cfg.network.regressor.keys():
            n_output = sum(self.param_list_dict[f"{key}_list"].values())
-           channels = ([2048] + self.cfg.network.regressor.get(key).channels +
-                       [n_output])
+           channels = ([2048] + self.cfg.network.regressor.get(key).channels + [n_output])
            if self.cfg.network.regressor.get(key).type == "mlp":
                self.Regressor[key] = MLP(channels=channels).to(self.device)
-           self.model_dict[f"Regressor_{key}"] = self.Regressor[
-               key].state_dict()
+           self.model_dict[f"Regressor_{key}"] = self.Regressor[key].state_dict()

        # Build the extractors
        # to extract separate head/left hand/right hand feature from body feature
        self.Extractor = {}
        for key in self.cfg.network.extractor.keys():
-           channels = [
-               2048
-           ] + self.cfg.network.extractor.get(key).channels + [2048]
+           channels = [2048] + self.cfg.network.extractor.get(key).channels + [2048]
            if self.cfg.network.extractor.get(key).type == "mlp":
                self.Extractor[key] = MLP(channels=channels).to(self.device)
-           self.model_dict[f"Extractor_{key}"] = self.Extractor[
-               key].state_dict()
+           self.model_dict[f"Extractor_{key}"] = self.Extractor[key].state_dict()

        # Build the moderators
        self.Moderator = {}
@@ -122,15 +113,13 @@ class PIXIE(object):
            share_part = key.split("_")[0]
            detach_inputs = self.cfg.network.moderator.get(key).detach_inputs
            detach_feature = self.cfg.network.moderator.get(key).detach_feature
-           channels = [2048 * 2
-                       ] + self.cfg.network.moderator.get(key).channels + [2]
+           channels = [2048 * 2] + self.cfg.network.moderator.get(key).channels + [2]
            self.Moderator[key] = TempSoftmaxFusion(
                detach_inputs=detach_inputs,
                detach_feature=detach_feature,
                channels=channels,
            ).to(self.device)
-           self.model_dict[f"Moderator_{key}"] = self.Moderator[
-               key].state_dict()
+           self.model_dict[f"Moderator_{key}"] = self.Moderator[key].state_dict()

        # Build the SMPL-X body model, which we also use to represent faces and
        # hands, using the relevant parts only
@@ -147,9 +136,7 @@ class PIXIE(object):
            print(f"pixie trained model path: {model_path} does not exist!")
            exit()
        # eval mode
-       for module in [
-           self.Encoder, self.Regressor, self.Moderator, self.Extractor
-       ]:
+       for module in [self.Encoder, self.Regressor, self.Moderator, self.Extractor]:
            for net in module.values():
                net.eval()

@@ -185,14 +172,14 @@ class PIXIE(object):
        # crop
        cropper_key = "hand" if "hand" in part_key else part_key
        points_scale = image.shape[-2:]
-       cropped_image, tform = self.Cropper[cropper_key].crop(
-           image, points_for_crop, points_scale)
+       cropped_image, tform = self.Cropper[cropper_key].crop(image, points_for_crop, points_scale)
        # transform points(must be normalized to [-1.1]) accordingly
        cropped_points_dict = {}
        for points_key in points_dict.keys():
            points = points_dict[points_key]
            cropped_points = self.Cropper[cropper_key].transform_points(
-               points, tform, points_scale, normalize=True)
+               points, tform, points_scale, normalize=True
+           )
            cropped_points_dict[points_key] = cropped_points
        return cropped_image, cropped_points_dict

@@ -244,8 +231,7 @@ class PIXIE(object):
            # then predict share parameters
            feature[key][f"{key}_share"] = feature[key][key]
            share_dict = self.decompose_code(
-               self.Regressor[f"{part}_share"](
-                   feature[key][f"{part}_share"]),
+               self.Regressor[f"{part}_share"](feature[key][f"{part}_share"]),
                self.param_list_dict[f"{part}_share_list"],
            )
            # compose parameters
@@ -257,13 +243,16 @@ class PIXIE(object):
            f_body = feature["body"]["body"]
            # extract part feature
            for part_name in ["head", "left_hand", "right_hand"]:
-               feature["body"][f"{part_name}_share"] = self.Extractor[
-                   f"{part_name}_share"](f_body)
+               feature["body"][f"{part_name}_share"] = self.Extractor[f"{part_name}_share"](
+                   f_body
+               )

            # -- check if part crops are given, if not, crop parts by coarse body estimation
-           if (
-                   "head_image" not in data[key].keys() or "left_hand_image" not in
-                   data[key].keys() or "right_hand_image" not in data[key].keys()):
+           if (
+               "head_image" not in data[key].keys() or
+               "left_hand_image" not in data[key].keys() or
+               "right_hand_image" not in data[key].keys()
+           ):
                # - run without fusion to get coarse estimation, for cropping parts
                # body only
                body_dict = self.decompose_code(
@@ -272,29 +261,26 @@ class PIXIE(object):
                )
                # head share
                head_share_dict = self.decompose_code(
-                   self.Regressor["head" + "_share"](
-                       feature[key]["head" + "_share"]),
+                   self.Regressor["head" + "_share"](feature[key]["head" + "_share"]),
                    self.param_list_dict["head" + "_share_list"],
                )
                # right hand share
                right_hand_share_dict = self.decompose_code(
-                   self.Regressor["hand" + "_share"](
-                       feature[key]["right_hand" + "_share"]),
+                   self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]),
                    self.param_list_dict["hand" + "_share_list"],
                )
                # left hand share
                left_hand_share_dict = self.decompose_code(
-                   self.Regressor["hand" + "_share"](
-                       feature[key]["left_hand" + "_share"]),
+                   self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]),
                    self.param_list_dict["hand" + "_share_list"],
                )
                # change the dict name from right to left
-               left_hand_share_dict[
-                   "left_hand_pose"] = left_hand_share_dict.pop(
-                       "right_hand_pose")
-               left_hand_share_dict[
-                   "left_wrist_pose"] = left_hand_share_dict.pop(
-                       "right_wrist_pose")
+               left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop(
+                   "right_hand_pose"
+               )
+               left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
+                   "right_wrist_pose"
+               )
                param_dict[key] = {
                    **body_dict,
                    **head_share_dict,
@@ -304,21 +290,18 @@ class PIXIE(object):
                if body_only:
                    param_dict["moderator_weight"] = None
                    return param_dict
-               prediction_body_only = self.decode(param_dict[key],
-                                                  param_type="body")
+               prediction_body_only = self.decode(param_dict[key], param_type="body")
                # crop
                for part_name in ["head", "left_hand", "right_hand"]:
                    part = part_name.split("_")[-1]
                    points_dict = {
-                       "smplx_kpt":
-                       prediction_body_only["smplx_kpt"],
-                       "trans_verts":
-                       prediction_body_only["transformed_vertices"],
+                       "smplx_kpt": prediction_body_only["smplx_kpt"],
+                       "trans_verts": prediction_body_only["transformed_vertices"],
                    }
-                   image_hd = torchvision.transforms.Resize(1024)(
-                       data["body"]["image"])
+                   image_hd = torchvision.transforms.Resize(1024)(data["body"]["image"])
                    cropped_image, cropped_joints_dict = self.part_from_body(
-                       image_hd, part_name, points_dict)
+                       image_hd, part_name, points_dict
+                   )
                    data[key][part_name + "_image"] = cropped_image

            # -- encode features from part crops, then fuse feature using the weight from moderator
@@ -338,16 +321,12 @@ class PIXIE(object):
                    self.Regressor[f"{part}_share"](f_part),
                    self.param_list_dict[f"{part}_share_list"],
                )
-               param_dict["body_" + part_name] = {
-                   **part_dict,
-                   **part_share_dict
-               }
+               param_dict["body_" + part_name] = {**part_dict, **part_share_dict}

                # moderator to assign weight, then integrate features
-               f_body_out, f_part_out, f_weight = self.Moderator[
-                   f"{part}_share"](feature["body"][f"{part_name}_share"],
-                                    f_part,
-                                    work=True)
+               f_body_out, f_part_out, f_weight = self.Moderator[f"{part}_share"](
+                   feature["body"][f"{part_name}_share"], f_part, work=True
+               )
                if copy_and_paste:
                    # copy and paste strategy always trusts the results from part
                    feature["body"][f"{part_name}_share"] = f_part
@@ -355,8 +334,9 @@ class PIXIE(object):
                    # for hand, if part weight > 0.7 (very confident, then fully trust part)
                    part_w = f_weight[:, [1]]
                    part_w[part_w > 0.7] = 1.0
-                   f_body_out = (feature["body"][f"{part_name}_share"] *
-                                 (1.0 - part_w) + f_part * part_w)
+                   f_body_out = (
+                       feature["body"][f"{part_name}_share"] * (1.0 - part_w) + f_part * part_w
+                   )
                    feature["body"][f"{part_name}_share"] = f_body_out
                else:
                    feature["body"][f"{part_name}_share"] = f_body_out
@@ -367,29 +347,24 @@ class PIXIE(object):
            # -- predict parameters from fused body feature
            # head share
            head_share_dict = self.decompose_code(
-               self.Regressor["head" + "_share"](feature[key]["head" +
-                                                              "_share"]),
+               self.Regressor["head" + "_share"](feature[key]["head" + "_share"]),
                self.param_list_dict["head" + "_share_list"],
            )
            # right hand share
            right_hand_share_dict = self.decompose_code(
-               self.Regressor["hand" + "_share"](
-                   feature[key]["right_hand" + "_share"]),
+               self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]),
                self.param_list_dict["hand" + "_share_list"],
            )
            # left hand share
            left_hand_share_dict = self.decompose_code(
-               self.Regressor["hand" + "_share"](
-                   feature[key]["left_hand" + "_share"]),
+               self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]),
                self.param_list_dict["hand" + "_share_list"],
            )
            # change the dict name from right to left
-           left_hand_share_dict[
-               "left_hand_pose"] = left_hand_share_dict.pop(
-                   "right_hand_pose")
-           left_hand_share_dict[
-               "left_wrist_pose"] = left_hand_share_dict.pop(
-                   "right_wrist_pose")
+           left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop("right_hand_pose")
+           left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
+               "right_wrist_pose"
+           )
            param_dict["body"] = {
                **body_dict,
                **head_share_dict,
@@ -403,10 +378,10 @@ class PIXIE(object):
        if keep_local:
            # for local change that will not affect whole body and produce unnatral pose, trust part
            param_dict[key]["exp"] = param_dict["body_head"]["exp"]
-           param_dict[key]["right_hand_pose"] = param_dict[
-               "body_right_hand"]["right_hand_pose"]
-           param_dict[key]["left_hand_pose"] = param_dict[
-               "body_left_hand"]["right_hand_pose"]
+           param_dict[key]["right_hand_pose"] = param_dict["body_right_hand"][
+               "right_hand_pose"]
+           param_dict[key]["left_hand_pose"] = param_dict["body_left_hand"][
+               "right_hand_pose"]

        return param_dict

@@ -426,75 +401,70 @@ class PIXIE(object):
            if "pose" in key and "jaw" not in key:
                param_dict[key] = converter.batch_cont2matrix(param_dict[key])
        if param_type == "body" or param_type == "head":
-           param_dict["jaw_pose"] = converter.batch_euler2matrix(
-               param_dict["jaw_pose"])[:, None, :, :]
+           param_dict["jaw_pose"] = converter.batch_euler2matrix(param_dict["jaw_pose"]
+                                                                 )[:, None, :, :]

        # complement params if it's not in given param dict
        if param_type == "head":
            batch_size = param_dict["shape"].shape[0]
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["global_pose"] = param_dict["head_pose"]
-           param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(
-               0).expand(batch_size, -1, -1,
-                         -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]]
-           param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
-               batch_size, -1, -1, -1)
-           param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
-               0).expand(batch_size, -1, -1, -1)
-           param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(
-               0).expand(batch_size, -1, -1, -1)
-           param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
-               0).expand(batch_size, -1, -1, -1)
-           param_dict["right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(
-               0).expand(batch_size, -1, -1, -1)
+           param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
+           param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )
+           param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )
+           param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )
+           param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )
+           param_dict["right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )
        elif param_type == "hand":
            batch_size = param_dict["right_hand_pose"].shape[0]
-           param_dict["abs_right_wrist_pose"] = param_dict[
-               "right_wrist_pose"].clone()
+           param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            dtype = param_dict["right_hand_pose"].dtype
            device = param_dict["right_hand_pose"].device
-           x_180_pose = (torch.eye(3, dtype=dtype,
-                                   device=device).unsqueeze(0).repeat(
-                                       1, 1, 1))
+           x_180_pose = (torch.eye(3, dtype=dtype, device=device).unsqueeze(0).repeat(1, 1, 1))
            x_180_pose[0, 2, 2] = -1.0
            x_180_pose[0, 1, 1] = -1.0
-           param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(
-               batch_size, -1, -1, -1)
-           param_dict["shape"] = self.smplx.shape_params.expand(
-               batch_size, -1)
-           param_dict["exp"] = self.smplx.expression_params.expand(
-               batch_size, -1)
+           param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
+           param_dict["shape"] = self.smplx.shape_params.expand(batch_size, -1)
+           param_dict["exp"] = self.smplx.expression_params.expand(batch_size, -1)
            param_dict["head_pose"] = self.smplx.head_pose.unsqueeze(0).expand(
-               batch_size, -1, -1, -1)
+               batch_size, -1, -1, -1
+           )
            param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
-               batch_size, -1, -1, -1)
-           param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(
-               batch_size, -1, -1, -1)
-           param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(
-               0).expand(batch_size, -1, -1,
-                         -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]]
-           param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(
-               0).expand(batch_size, -1, -1, -1)
-           param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(
-               0).expand(batch_size, -1, -1, -1)
+               batch_size, -1, -1, -1
+           )
+           param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
+           param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
+           param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )
+           param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
+               batch_size, -1, -1, -1
+           )
        elif param_type == "body":
            # the predcition from the head and hand share regressor is always absolute pose
            batch_size = param_dict["shape"].shape[0]
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
-           param_dict["abs_right_wrist_pose"] = param_dict[
-               "right_wrist_pose"].clone()
-           param_dict["abs_left_wrist_pose"] = param_dict[
-               "left_wrist_pose"].clone()
+           param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
+           param_dict["abs_left_wrist_pose"] = param_dict["left_wrist_pose"].clone()
            # the body-hand share regressor is working for right hand
            # so we assume body network get the flipped feature for the left hand. then get the parameters
            # then we need to flip it back to left, which matches the input left hand
-           param_dict["left_wrist_pose"] = util.flip_pose(
-               param_dict["left_wrist_pose"])
-           param_dict["left_hand_pose"] = util.flip_pose(
-               param_dict["left_hand_pose"])
+           param_dict["left_wrist_pose"] = util.flip_pose(param_dict["left_wrist_pose"])
+           param_dict["left_hand_pose"] = util.flip_pose(param_dict["left_hand_pose"])
        else:
            exit()

@@ -508,8 +478,7 @@ class PIXIE(object):
        Returns:
            predictions: smplx predictions
        """
-       if "jaw_pose" in param_dict.keys() and len(
-               param_dict["jaw_pose"].shape) == 2:
+       if "jaw_pose" in param_dict.keys() and len(param_dict["jaw_pose"].shape) == 2:
            self.convert_pose(param_dict, param_type)
        elif param_dict["right_wrist_pose"].shape[-1] == 6:
            self.convert_pose(param_dict, param_type)
@@ -532,9 +501,8 @@ class PIXIE(object):
        # change absolute head&hand pose to relative pose according to rest body pose
        if param_type == "head" or param_type == "body":
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
-               param_dict["global_pose"],
-               param_dict["body_pose"],
-               abs_joint="head")
+               param_dict["global_pose"], param_dict["body_pose"], abs_joint="head"
+           )
        if param_type == "hand" or param_type == "body":
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"],
@@ -550,7 +518,7 @@ class PIXIE(object):
        if self.cfg.model.check_pose:
            # check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose)
            # xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left)
-           for pose_ind in [14]:
+           for pose_ind in [14]:    # head [15-1, 20-1, 21-1]:
                curr_pose = param_dict["body_pose"][:, pose_ind]
                euler_pose = converter._compute_euler_from_matrix(curr_pose)
                for i, max_angle in enumerate([20, 70, 10]):
@@ -560,9 +528,7 @@ class PIXIE(object):
                        min=-max_angle * np.pi / 180,
                        max=max_angle * np.pi / 180,
                    )] = 0.0
-               param_dict[
-                   "body_pose"][:, pose_ind] = converter.batch_euler2matrix(
-                       euler_pose)
+               param_dict["body_pose"][:, pose_ind] = converter.batch_euler2matrix(euler_pose)

        # SMPLX
        verts, landmarks, joints = self.smplx(
@@ -594,8 +560,8 @@ class PIXIE(object):

        # change the order of face keypoints, to be the same as "standard" 68 keypoints
        prediction["face_kpt"] = torch.cat(
-           [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]],
-           dim=1)
+           [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]], dim=1
+       )

        prediction.update(param_dict)
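The fusion logic above blends the body-branch feature with each part feature using the moderator's part weight, and fully trusts the part branch once its weight exceeds 0.7. A small self-contained sketch of that blending rule (batch size and the 2048-dim features are illustrative):

```python
import torch

def fuse_part_feature(f_body, f_part, part_weight, hard_trust=0.7):
    # part_weight: (batch, 1) confidence for the part branch, in [0, 1].
    w = part_weight.clone()
    w[w > hard_trust] = 1.0                      # very confident part -> trust it completely
    return f_body * (1.0 - w) + f_part * w       # otherwise a soft blend

f_body = torch.randn(4, 2048)
f_part = torch.randn(4, 2048)
w = torch.rand(4, 1)
fused = fuse_part_feature(f_body, f_part, w)
```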
lib/pixielib/utils/array_cropper.py
CHANGED
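For reference, `points2bbox` in the diff below turns a set of 2D keypoints into a square crop box: the center of the tight bounding box plus the larger of its two extents. A minimal equivalent, assuming an (N, 2) array of pixel coordinates:

```python
import numpy as np

def square_bbox_from_points(points):
    # points: (N, 2) array of x, y pixel coordinates.
    left, right = points[:, 0].min(), points[:, 0].max()
    top, bottom = points[:, 1].min(), points[:, 1].max()
    size = max(right - left, bottom - top)
    center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
    return center, size
```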
@@ -23,15 +23,14 @@ def points2bbox(points, points_scale=None):
    bottom = np.max(points[:, 1])
    size = max(right - left, bottom - top)
    # + old_size*0.1])
-   center = np.array(
-       [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
+   center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
    return center, size
    # translate center


def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.0):
    trans_scale = (np.random.rand(2) * 2 - 1) * trans_scale
-   center = center + trans_scale * bbox_size
+   center = center + trans_scale * bbox_size    # 0.5
    scale = np.random.rand() * (scale[1] - scale[0]) + scale[0]
    size = int(bbox_size * scale)
    return center, size
@@ -48,27 +47,25 @@ def crop_array(image, center, bboxsize, crop_size):
        tform: 3x3 affine matrix
    """
    # points: top-left, top-right, bottom-right
-   src_pts = np.array(
-       [
-           [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
-           [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
-           [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
-       ])
-   DST_PTS = np.array([[0, 0], [crop_size - 1, 0],
-                       [crop_size - 1, crop_size - 1]])
+   src_pts = np.array(
+       [
+           [center[0] - bboxsize / 2, center[1] - bboxsize / 2],
+           [center[0] + bboxsize / 2, center[1] - bboxsize / 2],
+           [center[0] + bboxsize / 2, center[1] + bboxsize / 2],
+       ]
+   )
+   DST_PTS = np.array([[0, 0], [crop_size - 1, 0], [crop_size - 1, crop_size - 1]])

    # estimate transformation between points
    tform = estimate_transform("similarity", src_pts, DST_PTS)

    # warp images
-   cropped_image = warp(image,
-                        tform.inverse,
-                        output_shape=(crop_size, crop_size))
+   cropped_image = warp(image, tform.inverse, output_shape=(crop_size, crop_size))

    return cropped_image, tform.params.T


class Cropper(object):
-
    def __init__(self, crop_size, scale=[1, 1], trans_scale=0.0):
        self.crop_size = crop_size
        self.scale = scale
@@ -78,11 +75,9 @@ class Cropper(object):
        # points to bbox
        center, bbox_size = points2bbox(points, points_scale)
        # argument bbox.
-       center, bbox_size = augment_bbox(
-           center, bbox_size,
-           scale=self.scale,
-           trans_scale=self.trans_scale)
+       center, bbox_size = augment_bbox(
+           center, bbox_size, scale=self.scale, trans_scale=self.trans_scale
+       )
        # crop
-       cropped_image, tform = crop_array(image, center, bbox_size,
-                                         self.crop_size)
+       cropped_image, tform = crop_array(image, center, bbox_size, self.crop_size)
        return cropped_image, tform
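`crop_array` above estimates a similarity transform between three bounding-box corners and the corresponding output corners, then warps with scikit-image. A self-contained sketch of the same idea (the example image and box are placeholders):

```python
import numpy as np
from skimage.transform import estimate_transform, warp

def crop_by_bbox(image, center, bbox_size, crop_size):
    # Three corresponding points (top-left, top-right, bottom-right) define a similarity transform.
    half = bbox_size / 2.0
    src_pts = np.array([
        [center[0] - half, center[1] - half],
        [center[0] + half, center[1] - half],
        [center[0] + half, center[1] + half],
    ])
    dst_pts = np.array([[0, 0], [crop_size - 1, 0], [crop_size - 1, crop_size - 1]])
    tform = estimate_transform("similarity", src_pts, dst_pts)
    # warp expects the inverse map from output coordinates back to input coordinates.
    cropped = warp(image, tform.inverse, output_shape=(crop_size, crop_size))
    return cropped, tform.params.T

# Example usage with placeholder data: crop a 256x256 patch around the image center.
image = np.zeros((512, 512, 3))
cropped, tform = crop_by_bbox(image, center=np.array([256.0, 256.0]), bbox_size=300, crop_size=256)
```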