Commit 57510bb, committed by С Чичерин
1 Parent(s): 8029b4a
fixing issues with cpu
test.py CHANGED
@@ -526,10 +526,10 @@ def get_crossentropy_loss(gt,pre):
     return entropy_loss
 
 def get_alpha_loss(predict, alpha, trimap):
-    weighted = torch.zeros(trimap.shape).cuda()
+    weighted = torch.zeros(trimap.shape).to(device)
     weighted[trimap == 128] = 1.
     alpha_f = alpha / 255.
-    alpha_f = alpha_f.cuda()
+    alpha_f = alpha_f.to(device)
     diff = predict - alpha_f
     diff = diff * weighted
     alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
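The new .to(device) calls rely on a module-level device object that is defined elsewhere in test.py and does not appear in this diff. As an assumption for readers, the usual selector looks like this minimal sketch:

import torch

# Assumed, not shown in this diff: fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")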
@@ -537,9 +537,9 @@ def get_alpha_loss(predict, alpha, trimap):
     return alpha_loss_weighted
 
 def get_alpha_loss_whole_img(predict, alpha):
-    weighted = torch.ones(alpha.shape).cuda()
+    weighted = torch.ones(alpha.shape).to(device)
     alpha_f = alpha / 255.
-    alpha_f = alpha_f.cuda()
+    alpha_f = alpha_f.to(device)
     diff = predict - alpha_f
     alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
     alpha_loss = alpha_loss.sum()/(weighted.sum())
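For orientation, here is a rough CPU smoke test of get_alpha_loss as shown above. It assumes test.py is importable as a module, that the inputs are (N, 1, H, W) tensors, and that alpha is in [0, 255] while trimap uses 128 for the unknown region; these shapes and ranges are inferred from the context lines, not stated in the diff.

import torch
from test import get_alpha_loss                           # assumes test.py imports cleanly

predict = torch.rand(1, 1, 32, 32)                        # predicted alpha matte in [0, 1]
alpha = torch.randint(0, 256, (1, 1, 32, 32)).float()     # ground-truth alpha in [0, 255]
trimap = torch.full((1, 1, 32, 32), 128.0)                # 128 marks the unknown region
loss = get_alpha_loss(predict, alpha, trimap)             # predict must live on the same device
print(loss)                                               # as the module-level selector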
@@ -555,7 +555,7 @@ def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
     kernel = np.sum(gaussian(grid), axis=2)
     kernel /= np.sum(kernel)
     kernel = np.tile(kernel, (n_channels, 1, 1))
-    kernel = torch.FloatTensor(kernel[:, None, :, :]).cuda()
+    kernel = torch.FloatTensor(kernel[:, None, :, :]).to(device)
     return Variable(kernel, requires_grad=False)
 
 def conv_gauss(img, kernel):
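A side note, unrelated to this commit: torch.autograd.Variable has been a thin no-op wrapper since PyTorch 0.4, so the kernel could be returned as a plain tensor created directly on the target device. A self-contained sketch, where the all-ones array merely stands in for the Gaussian kernel computed above:

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kernel = np.ones((1, 5, 5), dtype=np.float32)         # stand-in for the computed Gaussian kernel
kernel_t = torch.as_tensor(kernel[:, None, :, :], dtype=torch.float32, device=device)
print(kernel_t.shape, kernel_t.requires_grad)         # torch.Size([1, 1, 5, 5]) False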
@@ -576,10 +576,10 @@ def laplacian_pyramid(img, kernel, max_levels=5):
     return pyr
 
 def get_laplacian_loss(predict, alpha, trimap):
-    weighted = torch.zeros(trimap.shape).cuda()
+    weighted = torch.zeros(trimap.shape).to(device)
     weighted[trimap == 128] = 1.
     alpha_f = alpha / 255.
-    alpha_f = alpha_f.cuda()
+    alpha_f = alpha_f.to(device)
     alpha_f = alpha_f.clone()*weighted
     predict = predict.clone()*weighted
     gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
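Note that the call above still hard-codes cuda=True even though the kernel inside build_gauss_kernel is now moved with .to(device); whether the flag has any remaining effect is not fully visible in this diff. A hypothetical drop-in for that line (not part of this commit) would derive the flag from the runtime instead:

# Hypothetical variant, not in this commit: let the runtime decide.
gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1,
                                  cuda=torch.cuda.is_available())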
@@ -590,7 +590,7 @@ def get_laplacian_loss(predict, alpha, trimap):
 
 def get_laplacian_loss_whole_img(predict, alpha):
     alpha_f = alpha / 255.
-    alpha_f = alpha_f.cuda()
+    alpha_f = alpha_f.to(device)
     gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
     pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
     pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
@@ -598,7 +598,7 @@ def get_laplacian_loss_whole_img(predict, alpha):
     return laplacian_loss
 
 def get_composition_loss_whole_img(img, alpha, fg, bg, predict):
-    weighted = torch.ones(alpha.shape).cuda()
+    weighted = torch.ones(alpha.shape).to(device)
     predict_3 = torch.cat((predict, predict, predict), 1)
     comp = predict_3 * fg + (1. - predict_3) * bg
     comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12)
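As a quick illustration of the compositing line above: when the prediction equals the true alpha, the composite reproduces the input image exactly, so the per-pixel composition loss collapses to sqrt(1e-12). A self-contained check with random tensors (alpha is already scaled to [0, 1] here for brevity):

import torch

fg = torch.rand(1, 3, 8, 8)
bg = torch.rand(1, 3, 8, 8)
alpha = torch.rand(1, 1, 8, 8)                    # plays the role of a perfect prediction
img = alpha * fg + (1.0 - alpha) * bg             # image composited from fg and bg
predict_3 = torch.cat((alpha, alpha, alpha), 1)   # repeat the single-channel matte per RGB channel
comp = predict_3 * fg + (1.0 - predict_3) * bg
print(torch.allclose(comp, img))                  # True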
@@ -781,7 +781,7 @@ def inference_img(model, img):
     img=cv2.copyMakeBorder(img, 8-h%8, 0, 8-w%8, 0, cv2.BORDER_REFLECT)
     # print(img.shape)
 
-    tensor_img = torch.from_numpy(img).permute(2, 0, 1).cuda()
+    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
     input_t = tensor_img
     input_t = input_t/255.0
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
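The copyMakeBorder call above pads the image so its height and width become multiples of 8 before inference. One caveat, unrelated to this commit: as written on that line, a side that is already a multiple of 8 gets 8 - h % 8 == 8 extra pixels of padding unless this is guarded elsewhere in the function. A hypothetical drop-in (h, w and img as in inference_img):

# Hypothetical variant, not in this commit: skip padding for sides already divisible by 8.
pad_h = (8 - h % 8) % 8
pad_w = (8 - w % 8) % 8
img = cv2.copyMakeBorder(img, pad_h, 0, pad_w, 0, cv2.BORDER_REFLECT)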
@@ -839,7 +839,7 @@ def test_am2k(model):
     alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
 
     with torch.no_grad():
-        torch.cuda.empty_cache()
+        # torch.cuda.empty_cache()
         predict = inference_img( model, img)
 
 
@@ -926,7 +926,7 @@ def test_p3m10k(model,dataset_choice, max_image=-1):
     trimap = trimap[:,:,0] if trimap.ndim>2 else trimap
     alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
     with torch.no_grad():
-        torch.cuda.empty_cache()
+        # torch.cuda.empty_cache()
         start = time.time()
 
 
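Both test loops now have torch.cuda.empty_cache() commented out, which matches the CPU-focused intent of the commit. A hypothetical alternative, not what this commit does, would keep cache clearing for GPU runs by guarding the call:

import torch

if torch.cuda.is_available():
    torch.cuda.empty_cache()      # only meaningful once a CUDA context exists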