# Provenance: uploaded to a Hugging Face Space by user zhaoyian01
# (commit "Add application file", rev 6d1366a). File-viewer chrome
# ("raw / history / blame / 14.3 kB") removed so the module parses.
import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
from .base import BasePredictor
class BRSBasePredictor(BasePredictor):
    """Common machinery for Backpropagating Refinement Scheme (BRS) predictors.

    Holds the optimization functor and the cached L-BFGS state (``opt_data``)
    plus cached network inputs (``input_data``); subclasses decide what the
    optimized variables mean (feature scale/bias, input bias, ...).
    """

    def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
        super().__init__(model, device, **kwargs)
        # BRS optimization kicks in only once more than this many clicks exist.
        self.optimize_after_n_clicks = optimize_after_n_clicks
        self.opt_functor = opt_functor

        self.opt_data = None      # flat vector of variables optimized by L-BFGS
        self.input_data = None    # cached network/head input for the current image

    def set_input_image(self, image):
        super().set_input_image(image)
        # Cached optimization state belongs to the previous image; invalidate it.
        self.opt_data = None
        self.input_data = None

    def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
        """Rasterize click lists into binary positive/negative click maps.

        Returns two float32 tensors of shape (len(clicks_lists), 1) + image_shape
        on ``self.device`` with a (2*radius+1)-square of ones around each click
        (clipped to the image bounds).
        """
        pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
        neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
        for list_indx, clicks_list in enumerate(clicks_lists):
            for click in clicks_list:
                y, x = click.coords
                y, x = int(round(y)), int(round(x))
                # Clamp the window start at 0: a negative slice start wraps
                # around to the end of the array, yielding an empty slice and
                # silently dropping clicks near the top/left image border.
                y1, x1 = max(y - radius, 0), max(x - radius, 0)
                y2, x2 = y + radius + 1, x + radius + 1
                if click.is_positive:
                    pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = 1.0
                else:
                    neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = 1.0

        with torch.no_grad():
            pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
            neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)

        return pos_clicks_map, neg_clicks_map

    def get_states(self):
        """Snapshot transform state and the current L-BFGS vector (for undo/redo)."""
        return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}

    def set_states(self, states):
        """Restore a snapshot produced by :meth:`get_states`."""
        self._set_transform_states(states['transform_states'])
        self.opt_data = states['opt_data']
class FeatureBRSPredictor(BRSBasePredictor):
    """BRS predictor that optimizes per-channel scale/bias corrections applied
    to intermediate features of a DeepLab-style segmentation network.

    ``insertion_mode`` selects where the correction is injected:
      - 'after_deeplab': after the whole feature extractor (num_channels = ch)
      - 'after_c4':      after backbone stage c4, before the ASPP module
      - 'after_aspp':    after ASPP + low-level skip concat (num_channels = ch + 32)
    """

    def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        # Low-level skip features cached by _get_head_input; consumed by the
        # 'after_c4' prediction path.
        self._c1_features = None

        if self.insertion_mode == 'after_deeplab':
            self.num_channels = model.feature_extractor.ch
        elif self.insertion_mode == 'after_c4':
            self.num_channels = model.feature_extractor.aspp_in_channels
        elif self.insertion_mode == 'after_aspp':
            # +32 presumably matches the skip_project output width — TODO confirm
            self.num_channels = model.feature_extractor.ch + 32
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Run BRS-refined prediction: optimize a (scale, bias) vector over the
        cached backbone features with L-BFGS so the output respects the clicks,
        then return the logits for the best parameters found."""
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])

        num_clicks = len(clicks_lists[0])
        # With test-time flip the batch holds original + flipped copies;
        # bs is the per-orientation batch size.
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        # (Re)allocate the flat opt vector: scale + bias => 2 * num_channels per image.
        if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        # Beyond net_clicks_limit the encoded click input stops changing, so the
        # cached head input can be reused; recompute otherwise.
        if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            # scale/bias arrive flat; broadcast them per-channel over H x W.
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                # Same correction is applied to original and flipped halves.
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == 'after_c4':
                # Finish the DeepLab head: ASPP -> upsample -> concat skip -> head.
                x = self.net.feature_extractor.aspp(scaled_backbone_features)
                x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
                                  align_corners=True)
                x = torch.cat((x, self._c1_features), dim=1)
                scaled_backbone_features = self.net.feature_extractor.head(x)
            elif self.insertion_mode == 'after_aspp':
                scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)

            pred_logits = self.net.head(scaled_backbone_features)
            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
                                        align_corners=True)
            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
        if num_clicks > self.optimize_after_n_clicks:
            # fmin_l_bfgs_b returns (x_opt, f_opt, info); keep the optimized vector.
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
                                       **self.opt_functor.optimizer_params)
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                # The functor already cached the best logits seen during optimization.
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # No optimization ran (few clicks): evaluate with current opt_data.
                opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        """Compute (without gradients) the intermediate features at the chosen
        insertion point; for 'after_c4'/'after_aspp' also cache the projected
        low-level c1 skip features on self._c1_features."""
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                # Early-fusion variant: clicks are merged into the image stream.
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, 'maps_transform'):
                # Late-fusion variant: clicks enter the backbone separately.
                # NOTE(review): if neither branch matches, x/additional_features
                # are unbound and the code below raises NameError — confirm the
                # model always provides one of the two.
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)

            if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
                c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features)
                c1 = self.net.feature_extractor.skip_project(c1)

                if self.insertion_mode == 'after_aspp':
                    x = self.net.feature_extractor.aspp(c4)
                    x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
                    x = torch.cat((x, c1), dim=1)
                    backbone_features = x
                else:
                    backbone_features = c4
                    self._c1_features = c1
            else:
                backbone_features = self.net.feature_extractor(x, additional_features)[0]

        return backbone_features
class HRNetFeatureBRSPredictor(BRSBasePredictor):
    """BRS predictor for HRNet+OCR backbones; optimizes per-channel scale/bias
    corrections applied to intermediate HRNet features.

    ``insertion_mode``:
      - 'A': correction applied to the concatenated HRNet branch features,
             before the (optional) OCR module and classification head
      - 'C': correction applied to the OCR-distributed features,
             right before the classification head
    """

    def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        self._c1_features = None  # unused here; kept for parity with the DeepLab variant

        if self.insertion_mode == 'A':
            # HRNet concatenates 4 branches of widths width*{1,2,4,8}.
            self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
        elif self.insertion_mode == 'C':
            self.num_channels = 2 * model.feature_extractor.ocr_width
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Run BRS-refined prediction: L-BFGS over a (scale, bias) correction of
        the cached HRNet features, then return logits for the best parameters."""
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])
        # With test-time flip the batch holds original + flipped copies.
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        # (Re)allocate the flat opt vector: scale + bias => 2 * num_channels per image.
        if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        # Beyond net_clicks_limit the encoded click input stops changing, so the
        # cached head input can be reused; recompute otherwise.
        if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            # scale/bias arrive flat; broadcast them per-channel over H x W.
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                # Same correction is applied to original and flipped halves.
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == 'A':
                if self.net.feature_extractor.ocr_width > 0:
                    # Run the full OCR pipeline on the corrected features.
                    out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
                    feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)

                    context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                    feats = self.net.feature_extractor.ocr_distri_head(feats, context)
                else:
                    feats = scaled_backbone_features
                pred_logits = self.net.feature_extractor.cls_head(feats)
            elif self.insertion_mode == 'C':
                # OCR already ran in _get_head_input; only the final head is left.
                pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
            else:
                raise NotImplementedError

            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
                                        align_corners=True)
            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
        if num_clicks > self.optimize_after_n_clicks:
            # fmin_l_bfgs_b returns (x_opt, f_opt, info); keep the optimized vector.
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
                                       **self.opt_functor.optimizer_params)
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                # The functor already cached the best logits seen during optimization.
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # No optimization ran (few clicks): evaluate with current opt_data.
                opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        """Compute (without gradients) the HRNet features at the chosen
        insertion point: raw branch features for mode 'A', OCR-distributed
        features for mode 'C'."""
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                # Early-fusion variant: clicks are merged into the image stream.
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, 'maps_transform'):
                # Late-fusion variant: clicks enter the backbone separately.
                # NOTE(review): if neither branch matches, x/additional_features
                # are unbound and the call below raises NameError — confirm the
                # model always provides one of the two.
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)

            feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features)

            if self.insertion_mode == 'A':
                backbone_features = feats
            elif self.insertion_mode == 'C':
                out_aux = self.net.feature_extractor.aux_head(feats)
                feats = self.net.feature_extractor.conv3x3_ocr(feats)

                context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
            else:
                raise NotImplementedError

        return backbone_features
class InputBRSPredictor(BRSBasePredictor):
    """BRS predictor that optimizes an additive bias over the network *input*
    instead of intermediate features.

    ``optimize_target``:
      - 'rgb':   bias added to the prepared input image (3 channels)
      - 'dmaps': bias added to the click/distance maps (skipping the previous-mask
                 channel when the model uses one)
      - 'all':   bias added to the fused features right after rgb_conv
                 (presumably also 3 channels, via the else branch below — TODO confirm)
    """

    def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.optimize_target = optimize_target

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Run BRS-refined prediction by optimizing an input-space bias with
        L-BFGS; returns logits for the best bias found."""
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])

        # Unlike the feature variants, opt_data here is a torch tensor shaped
        # like the target input; it is flattened only for the scipy optimizer.
        if self.opt_data is None or is_image_changed:
            if self.optimize_target == 'dmaps':
                # Skip the previous-mask channel (channel 0) when present.
                opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch
            else:
                opt_channels = 3
            bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
            self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
                                        device=self.device, dtype=torch.float32)

        def get_prediction_logits(opt_bias):
            input_image, prev_mask = self.net.prepare_input(image_nd)
            dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)

            if self.optimize_target == 'rgb':
                input_image = input_image + opt_bias
            elif self.optimize_target == 'dmaps':
                if self.net.with_prev_mask:
                    # Channel 0 is the previous mask; bias only the click maps.
                    dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
                else:
                    dmaps = dmaps + opt_bias

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
                if self.optimize_target == 'all':
                    x = x + opt_bias
                coord_features = None
            elif hasattr(self.net, 'maps_transform'):
                # NOTE(review): if neither branch matches, x/coord_features are
                # unbound and backbone_forward raises NameError — confirm the
                # model always provides one of the two.
                x = input_image
                coord_features = self.net.maps_transform(dmaps)

            pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances']
            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)

            return pred_logits

        # shape= lets the functor reshape the flat scipy vector back to a tensor.
        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
                                    shape=self.opt_data.shape)
        if num_clicks > self.optimize_after_n_clicks:
            # scipy needs a flat float64 numpy vector; flatten, optimize, restore.
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
                                       **self.opt_functor.optimizer_params)

            self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                # The functor already cached the best logits seen during optimization.
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # No optimization ran (few clicks): evaluate with current opt_data.
                opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits