Spaces:
Running
on
Zero
Running
on
Zero
ZhengPeng7
commited on
Commit
•
e2ce7e5
1
Parent(s):
327742a
Remove redundant part of our_ref in inference.
Browse files- app.py +1 -1
- models/baseline.py +24 -22
app.py
CHANGED
@@ -35,7 +35,7 @@ class ImagePreprocessor():
|
|
35 |
return image
|
36 |
|
37 |
|
38 |
-
model = BiRefNet().to(device)
|
39 |
state_dict = './BiRefNet_ep580.pth'
|
40 |
if os.path.exists(state_dict):
|
41 |
birefnet_dict = torch.load(state_dict, map_location=device)
|
|
|
35 |
return image
|
36 |
|
37 |
|
38 |
+
model = BiRefNet(bb_pretrained=False).to(device)
|
39 |
state_dict = './BiRefNet_ep580.pth'
|
40 |
if os.path.exists(state_dict):
|
41 |
birefnet_dict = torch.load(state_dict, map_location=device)
|
models/baseline.py
CHANGED
@@ -20,11 +20,11 @@ from models.refinement.stem_layer import StemLayer
|
|
20 |
|
21 |
|
22 |
class BiRefNet(nn.Module):
|
23 |
-
def __init__(self):
|
24 |
super(BiRefNet, self).__init__()
|
25 |
self.config = Config()
|
26 |
self.epoch = 1
|
27 |
-
self.bb = build_backbone(self.config.bb, pretrained=
|
28 |
|
29 |
channels = self.config.lateral_channels_in_collection
|
30 |
|
@@ -126,7 +126,7 @@ class BiRefNet(nn.Module):
|
|
126 |
x4 = self.squeeze_module(x4)
|
127 |
########## Decoder ##########
|
128 |
features = [x, x1, x2, x3, x4]
|
129 |
-
if self.config.out_ref:
|
130 |
features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
|
131 |
scaled_preds = self.decoder(features)
|
132 |
return scaled_preds, class_preds
|
@@ -231,7 +231,7 @@ class Decoder(nn.Module):
|
|
231 |
return torch.cat(patches_batch, dim=0)
|
232 |
|
233 |
def forward(self, features):
|
234 |
-
if self.config.out_ref:
|
235 |
outs_gdt_pred = []
|
236 |
outs_gdt_label = []
|
237 |
x, x1, x2, x3, x4, gdt_gt = features
|
@@ -249,18 +249,19 @@ class Decoder(nn.Module):
|
|
249 |
p3 = self.decoder_block3(_p3)
|
250 |
m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
|
251 |
if self.config.out_ref:
|
252 |
-
# >> GT:
|
253 |
-
# m3 --dilation--> m3_dia
|
254 |
-
# G_3^gt * m3_dia --> G_3^m, which is the label of gradient
|
255 |
-
m3_dia = m3
|
256 |
-
gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
257 |
-
outs_gdt_label.append(gdt_label_main_3)
|
258 |
-
# >> Pred:
|
259 |
-
# p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
|
260 |
-
# F_3^G --sigmoid--> A_3^G
|
261 |
p3_gdt = self.gdt_convs_3(p3)
|
262 |
-
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
265 |
# >> Finally:
|
266 |
# p3 = p3 * A_3^G
|
@@ -274,14 +275,15 @@ class Decoder(nn.Module):
|
|
274 |
p2 = self.decoder_block2(_p2)
|
275 |
m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
|
276 |
if self.config.out_ref:
|
277 |
-
# >> GT:
|
278 |
-
m2_dia = m2
|
279 |
-
gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
280 |
-
outs_gdt_label.append(gdt_label_main_2)
|
281 |
-
# >> Pred:
|
282 |
p2_gdt = self.gdt_convs_2(p2)
|
283 |
-
|
284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
286 |
# >> Finally:
|
287 |
p2 = p2 * gdt_attn_2
|
|
|
20 |
|
21 |
|
22 |
class BiRefNet(nn.Module):
|
23 |
+
def __init__(self, bb_pretrained=True):
|
24 |
super(BiRefNet, self).__init__()
|
25 |
self.config = Config()
|
26 |
self.epoch = 1
|
27 |
+
self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
|
28 |
|
29 |
channels = self.config.lateral_channels_in_collection
|
30 |
|
|
|
126 |
x4 = self.squeeze_module(x4)
|
127 |
########## Decoder ##########
|
128 |
features = [x, x1, x2, x3, x4]
|
129 |
+
if self.training and self.config.out_ref:
|
130 |
features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
|
131 |
scaled_preds = self.decoder(features)
|
132 |
return scaled_preds, class_preds
|
|
|
231 |
return torch.cat(patches_batch, dim=0)
|
232 |
|
233 |
def forward(self, features):
|
234 |
+
if self.training and self.config.out_ref:
|
235 |
outs_gdt_pred = []
|
236 |
outs_gdt_label = []
|
237 |
x, x1, x2, x3, x4, gdt_gt = features
|
|
|
249 |
p3 = self.decoder_block3(_p3)
|
250 |
m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
|
251 |
if self.config.out_ref:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
p3_gdt = self.gdt_convs_3(p3)
|
253 |
+
if self.training:
|
254 |
+
# >> GT:
|
255 |
+
# m3 --dilation--> m3_dia
|
256 |
+
# G_3^gt * m3_dia --> G_3^m, which is the label of gradient
|
257 |
+
m3_dia = m3
|
258 |
+
gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
259 |
+
outs_gdt_label.append(gdt_label_main_3)
|
260 |
+
# >> Pred:
|
261 |
+
# p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
|
262 |
+
# F_3^G --sigmoid--> A_3^G
|
263 |
+
gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
|
264 |
+
outs_gdt_pred.append(gdt_pred_3)
|
265 |
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
266 |
# >> Finally:
|
267 |
# p3 = p3 * A_3^G
|
|
|
275 |
p2 = self.decoder_block2(_p2)
|
276 |
m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
|
277 |
if self.config.out_ref:
|
|
|
|
|
|
|
|
|
|
|
278 |
p2_gdt = self.gdt_convs_2(p2)
|
279 |
+
if self.training:
|
280 |
+
# >> GT:
|
281 |
+
m2_dia = m2
|
282 |
+
gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
283 |
+
outs_gdt_label.append(gdt_label_main_2)
|
284 |
+
# >> Pred:
|
285 |
+
gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
|
286 |
+
outs_gdt_pred.append(gdt_pred_2)
|
287 |
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
288 |
# >> Finally:
|
289 |
p2 = p2 * gdt_attn_2
|