Spaces:
Running
Running
import torch | |
import numpy as np | |
from fairseq import utils,tasks | |
from utils.checkpoint_utils import load_model_ensemble_and_task | |
from utils.eval_utils import eval_step | |
from tasks.refcoco import RefcocoTask | |
from models.polyformer import PolyFormerModel | |
from PIL import Image | |
import cv2 | |
import math | |
from skimage import draw | |
tasks.register_task('refcoco', RefcocoTask)

# Turn on CUDA if a GPU is available.
use_cuda = torch.cuda.is_available()
# Use fp16 only when a GPU is available. Both the models (model.half()) and
# the input samples (apply_half) are converted only when this flag is set, so
# on CPU everything stays in float32.
use_fp16 = use_cuda
# Load the pretrained PolyFormer checkpoint together with its config and task.
# `arg_overrides` replaces the checkpoint's `bpe_dir` with the local copy.
overrides = {"bpe_dir": "utils/BPE"}
models, cfg, task = load_model_ensemble_and_task(
    utils.split_paths("weights/polyformer_l_refcocog.pt"),
    arg_overrides=overrides,
)

# Demo-specific overrides of the stored config.
cfg.common.seed = 7
for _field, _value in (
    ("beam", 5),
    ("min_len", 12),
    ("max_len_a", 0),
    ("max_len_b", 420),
    ("no_repeat_ngram_size", 3),
):
    setattr(cfg.generation, _field, _value)
cfg.task.patch_image_size = 512
from bert.tokenization_bert import BertTokenizer | |
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Seed NumPy and torch so stochastic decoding is reproducible, unless no seed
# is configured or the config says none was provided.
if not (cfg.common.seed is None or cfg.generation.no_seed_provided):
    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

# Prepare every ensemble member for inference: eval mode, optional fp16,
# optional GPU placement, then the model's own inference hook (in that order).
for model in models:
    model.eval()
    if use_fp16:
        model.half()
    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
        model.cuda()
    model.prepare_for_inference_(cfg)

# Sequence generator built from the generation config. The hand-written
# decoding loop in visual_grounding does not go through it.
generator = task.build_generator(models, cfg.generation)
# Image transform | |
from torchvision import transforms | |
# Side length of the square patch image fed to the model.
patch_image_size = cfg.task.patch_image_size

# Per-channel normalisation: (x - 0.5) / 0.5 maps ToTensor's [0, 1] to [-1, 1].
mean = [0.5] * 3
std = [0.5] * 3
patch_resize_transform = transforms.Compose([
    lambda image: image.convert("RGB"),
    transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Special token ids of the source dictionary.
bos_item = torch.LongTensor([task.src_dict.bos()])
eos_item = torch.LongTensor([task.src_dict.eos()])
pad_idx = task.src_dict.pad()
def construct_sample(image: Image.Image, text: str):
    """Build a single-example batch for the RefCOCO-style PolyFormer model.

    Args:
        image: input PIL image; ``patch_resize_transform`` converts it to RGB
            and resizes it to ``patch_image_size`` x ``patch_image_size``.
        text: referring expression, inserted verbatim into the prompt.

    Returns:
        A dict whose ``net_input`` holds the BERT token ids, one source length
        per row, the attention mask, the normalised patch image and its patch
        mask, plus the width/height resize ratios and placeholder
        ground-truth fields.
    """
    w, h = image.size  # PIL reports (width, height)
    w_resize_ratio = torch.tensor(patch_image_size / w).unsqueeze(0)
    h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
    patch_image = patch_resize_transform(image).unsqueeze(0)
    patch_mask = torch.tensor([True])

    prompt = ' which region does the text " {} " describe?'.format(text)
    tokenized = tokenizer.batch_encode_plus([prompt], padding="longest", return_tensors="pt")
    src_tokens = tokenized["input_ids"]
    att_masks = tokenized["attention_mask"]
    # One length per row (shape (batch,)): count the non-padding positions of
    # each sequence rather than summing over the whole batch.
    src_lengths = att_masks.ne(0).long().sum(dim=1)

    sample = {
        "id": np.array(['42']),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "att_masks": att_masks,
            "patch_images": patch_image,
            "patch_masks": patch_mask,
        },
        "w_resize_ratios": w_resize_ratio,
        "h_resize_ratios": h_resize_ratio,
        # The fields below look like ground-truth placeholders (random box,
        # empty label sized like the patch image, 'None' polygon) - confirm
        # they are unused at inference time.
        "region_coords": torch.randn(1, 4),
        "label": np.zeros((patch_image_size, patch_image_size)),
        "poly": 'None',
        "text": text,
    }
    return sample
# Function to turn FP32 to FP16 | |
def apply_half(t):
    """Return ``t`` cast to float16 when it is a float32 tensor, otherwise ``t`` itself.

    Intended for ``utils.apply_to_sample`` so that only float32 tensors in a
    sample are halved; integer / bool / float64 tensors pass through unchanged.
    """
    return t.to(dtype=torch.half) if t.dtype == torch.float32 else t
from io import BytesIO | |
import base64 | |
import re | |
def pre_caption(caption):
    """Normalise a caption string.

    Lower-cases the text, strips leading ``,.!?*#:;~`` characters, turns ``-``
    and ``/`` into spaces, rewrites ``<person>`` as ``person``, collapses runs
    of two or more whitespace characters into one space, and trims a trailing
    newline and surrounding spaces.
    """
    text = caption.lower().lstrip(",.!?*#:;~")
    for old, new in (('-', ' '), ('/', ' '), ('<person>', 'person')):
        text = text.replace(old, new)
    text = re.sub(r"\s{2,}", ' ', text)
    return text.rstrip('\n').strip(' ')
def convert_pts(coeffs):
    """Turn a flat ``[x0, y0, x1, y1, ...]`` sequence into int32 ``(y, x)`` rows.

    A trailing unpaired value is ignored, and the int32 cast truncates
    fractional coordinates. An empty input yields an empty 1-D int32 array.
    """
    rows = [[y, x] for x, y in zip(coeffs[0::2], coeffs[1::2])]
    return np.array(rows, dtype=np.int32)
def get_mask_from_codes(codes, img_size):
    """Rasterise a list of flat polygons into one binary mask.

    Args:
        codes: iterable of flat ``[x0, y0, x1, y1, ...]`` polygons.
        img_size: ``(height, width)`` of the output mask.

    Returns:
        ``uint8`` array of shape ``img_size`` that is 1 inside any polygon and
        0 elsewhere.
    """
    mask = np.zeros(img_size, dtype=bool)
    for code in codes:
        pts = convert_pts(code)
        if len(pts) == 0:
            # Fewer than one (x, y) pair: polygon2mask cannot unpack an empty
            # (0,) array, and there is nothing to draw anyway.
            continue
        mask |= draw.polygon2mask(img_size, pts)
    return mask.astype(np.uint8)
def overlay_predictions(img, mask=None, polygons=None, bbox=None, color_box=(0, 255, 0),
                        color_mask=(255, 102, 102), color_poly=(255, 0, 0), thickness=3, radius=6):
    """Draw the predicted box, mask and polygon vertices on a copy of ``img``.

    Layers are applied in order: bounding box, then translucent mask, then
    vertex dots; any layer whose argument is ``None`` is skipped. ``img``
    itself is never modified.

    Args:
        img: image array to draw on.
        mask: label mask passed to ``overlay_davis`` (0 = background).
        polygons: flat ``[x0, y0, ...]`` polygons passed to ``plot_polygons``.
        bbox: ``(x1, y1, x2, y2)`` passed to ``draw_bbox``.
        color_box, color_mask, color_poly: RGB colours of the three layers
            (tuples, so the defaults cannot be mutated between calls).
        thickness: box line thickness.
        radius: vertex dot radius.

    Returns:
        The annotated copy.
    """
    overlayed = img.copy()
    if bbox is not None:
        overlayed = draw_bbox(overlayed, bbox, color=color_box, thickness=thickness)
    if mask is not None:
        overlayed = overlay_davis(overlayed, mask, colors=[[0, 0, 0], color_mask])
    if polygons is not None:
        overlayed = plot_polygons(overlayed, polygons, color=color_poly, radius=radius)
    return overlayed
def overlay_davis(image, mask, colors=((0, 0, 0), (255, 102, 102)), cscale=1, alpha=0.4):
    """Tint every labelled region of ``mask`` on a copy of ``image``.

    Pixels with label ``k != 0`` become
    ``alpha * image + (1 - alpha) * colors[k] * cscale``; label ``0`` is
    background and is left untouched.

    Args:
        image: ``(H, W, 3)`` array (colours are RGB triples broadcast over the
            last axis).
        mask: integer label map. If its ``(H, W)`` differs from ``image`` it is
            resized with nearest-neighbour interpolation.
        colors: one RGB row per label; row 0 is the background entry.
        cscale: scale factor applied to ``colors``.
        alpha: weight of the original image inside tinted regions.

    Returns:
        The blended image, cast back to ``image.dtype``.
    """
    palette = np.atleast_2d(np.reshape(colors, (-1, 3))) * cscale
    im_overlay = image.copy()
    h_i, w_i = image.shape[:2]
    if mask.shape[:2] != (h_i, w_i):
        # cv2.resize takes dsize as (width, height), not (height, width).
        mask = cv2.resize(mask, (w_i, h_i), interpolation=cv2.INTER_NEAREST)
    for object_id in np.unique(mask):
        if object_id == 0:
            # Skip background explicitly instead of dropping the first unique
            # value, which would skip a real label when no background exists.
            continue
        foreground = image * alpha + np.ones(image.shape) * (1 - alpha) * np.array(palette[object_id])
        binary_mask = mask == object_id
        im_overlay[binary_mask] = foreground[binary_mask]
    return im_overlay.astype(image.dtype)
def draw_bbox(img, box, color=(0, 255, 0), thickness=2):
    """Draw the rectangle ``box = (x1, y1, x2, y2)`` on ``img`` with OpenCV.

    Corner coordinates are truncated with ``int``; the image returned by
    ``cv2.rectangle`` is returned.
    """
    left, top, right, bottom = box
    top_left = (int(left), int(top))
    bottom_right = (int(right), int(bottom))
    return cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness)
def plot_polygons(img, polygons, color=(255, 0, 0), radius=7):
    """Draw a filled circle at every vertex of each polygon.

    Each polygon is a flat ``[x0, y0, x1, y1, ...]`` sequence. A trailing
    unpaired value is ignored, coordinates are truncated to integers, and a
    polygon without a complete ``(x, y)`` pair is skipped instead of raising
    IndexError.

    Returns:
        The image returned by the last ``cv2.circle`` call, or ``img``
        unchanged if nothing was drawn.
    """
    for polygon in polygons:
        n_pairs = len(polygon) // 2
        if n_pairs == 0:
            continue
        points = np.reshape(polygon[:2 * n_pairs], (n_pairs, 2)).astype(np.int16)
        for point in points:
            # OpenCV's documented point form is a tuple of Python ints.
            img = cv2.circle(img, tuple(point.tolist()), radius, color, thickness=-1)
    return img
def plot_arrow(img, polygons, color=(128, 128, 128), thickness=3, tip_length=0.3):
    """Draw arrows between consecutive vertices of each flat polygon.

    Each polygon is ``[x0, y0, x1, y1, ...]``; a trailing unpaired value is
    dropped and coordinates are cast to int16. Arrows follow vertex order and
    the polygon is not closed.
    """
    for flat in polygons:
        if not len(flat):
            continue
        usable = len(flat) - len(flat) % 2
        vertices = np.reshape(flat[:usable], (len(flat) // 2, 2)).astype(np.int16)
        for start, end in zip(vertices[:-1], vertices[1:]):
            img = cv2.arrowedLine(img, start, end, color, thickness=thickness, tipLength=tip_length)
    return img
def downsample_polygon(polygon, ds_rate=25):
    """Keep every ``ds_rate``-th vertex of a flat ``[x0, y0, x1, y1, ...]`` polygon.

    Returns the kept coordinates as a flat list of NumPy scalars.
    """
    n_vertices = len(polygon) // 2
    vertices = np.array(polygon).reshape(n_vertices, 2)
    kept = vertices[::ds_rate]
    return list(kept.reshape(-1))
def downsample_polygons(polygons, ds_rate=25):
    """Apply ``downsample_polygon`` with the same ``ds_rate`` to every polygon."""
    return [downsample_polygon(poly, ds_rate) for poly in polygons]
def _decode_polygon_sequence(model, sample, min_len=6, max_len=210, n_bins=64):
    """Autoregressively decode a box and polygon vertices with PolyFormer.

    The decoder is re-run on the whole prefix at every step until every
    sequence has predicted EOS or ``max_len`` steps were taken.

    Args:
        model: model exposing ``encoder`` and ``decoder`` as called below.
        sample: dict built by ``construct_sample`` (already on the target
            device / dtype).
        min_len: before this step an EOS prediction is treated as a coordinate.
        max_len: hard cap on the number of decoding steps.
        n_bins: side of the quantisation grid used to turn a continuous vertex
            into its four neighbouring grid tokens.

    Returns:
        One list per batch element containing normalised ``x, y`` values,
        ``2`` for a polygon separator and ``-1`` for EOS / padding steps.
    """
    net_input = sample["net_input"]
    device = net_input["patch_images"].device
    batch_size = net_input["patch_images"].shape[0]

    # Four token streams: the (floor/ceil x) x (floor/ceil y) grid corners
    # around each vertex. Token 0 starts every sequence here; 1 is padding,
    # 2 is EOS, 3 is the separator, and grid tokens are offset by 4.
    tokens_11 = [[0] for _ in range(batch_size)]
    tokens_12 = [[0] for _ in range(batch_size)]
    tokens_21 = [[0] for _ in range(batch_size)]
    tokens_22 = [[0] for _ in range(batch_size)]
    token_streams = (tokens_11, tokens_12, tokens_21, tokens_22)
    # Fractional offset of each vertex from its floor corner, and complements.
    delta_x1 = [[0] for _ in range(batch_size)]
    delta_y1 = [[0] for _ in range(batch_size)]
    delta_x2 = [[1] for _ in range(batch_size)]
    delta_y2 = [[1] for _ in range(batch_size)]
    gen_out = [[] for _ in range(batch_size)]
    unfinish_flag = np.ones(batch_size)

    encoder_out = model.encoder(
        net_input["src_tokens"],
        src_lengths=net_input["src_lengths"],
        att_masks=net_input["att_masks"],
        patch_images=net_input["patch_images"],
        patch_masks=net_input["patch_masks"],
        token_embeddings=None,
        return_all_hiddens=False,
        sample_patch_num=None,
    )

    def to_device(rows):
        return torch.tensor(np.array(rows)).to(device)

    i = 0
    while i < max_len and unfinish_flag.any():
        net_output = model.decoder(
            to_device(tokens_11).long(),
            to_device(tokens_12).long(),
            to_device(tokens_21).long(),
            to_device(tokens_22).long(),
            to_device(delta_x1),
            to_device(delta_y1),
            to_device(delta_x2),
            to_device(delta_y2),
            code_masks=None,
            encoder_out=encoder_out,
            features_only=False,
            alignment_layer=None,
            alignment_heads=None,
            src_lengths=net_input["src_lengths"],
            return_all_hiddens=False,
        )
        # Class per step: 0 coordinate, 1 separator, 2 EOS.
        cls_type = torch.argmax(net_output[0], 2)
        reg_output = net_output[1].squeeze(-1)

        for j in range(batch_size):
            delta_x = delta_y = 0
            if unfinish_flag[j] != 1:
                # Prediction already finished: pad to keep the batch aligned.
                gen_out[j].append(-1)
                tokens = (1, 1, 1, 1)
            else:
                cls_j = cls_type[j, i].item()
                if cls_j == 0 or (cls_j == 2 and i < min_len):
                    x, y = reg_output[j, i].cpu().numpy()
                    # Clamp to [0, 1] so the grid indices below stay inside
                    # the n_bins x n_bins grid (floor of a negative value
                    # would produce a negative token id).
                    x = min(max(x, 0), 1)
                    y = min(max(y, 0), 1)
                    gen_out[j].extend([x, y])
                    x = x * (n_bins - 1)
                    y = y * (n_bins - 1)
                    x_lo, y_lo = math.floor(x), math.floor(y)
                    x_hi, y_hi = math.ceil(x), math.ceil(y)
                    tokens = (
                        x_lo * n_bins + y_lo + 4,
                        x_lo * n_bins + y_hi + 4,
                        x_hi * n_bins + y_lo + 4,
                        x_hi * n_bins + y_hi + 4,
                    )
                    delta_x = x - x_lo
                    delta_y = y - y_lo
                elif cls_j == 1:
                    gen_out[j].append(2)  # 2 marks a separator in gen_out
                    tokens = (3, 3, 3, 3)
                else:  # EOS predicted at or after min_len
                    unfinish_flag[j] = 0
                    gen_out[j].append(-1)
                    tokens = (2, 2, 2, 2)
            for stream, token in zip(token_streams, tokens):
                stream[j].append(token)
            delta_x1[j].append(delta_x)
            delta_y1[j].append(delta_y)
            delta_x2[j].append(1 - delta_x)
            delta_y2[j].append(1 - delta_y)
        i += 1
    print("inference step: ", i)
    return gen_out


def _split_bbox_and_polygons(seq, w, h):
    """Split one decoded sequence into a pixel-space box and polygons.

    The first four values are the box ``x1, y1, x2, y2``. The rest are
    ``x, y`` pairs separated by ``2``. Each polygon is scaled on its own, so
    x/y parity restarts after every single-value separator.

    Returns:
        ``(bbox, polygons)``: a float64 array of up to four values and a
        list of flat float64 ``[x, y, ...]`` arrays (empty segments dropped).
    """
    values = np.asarray(seq, dtype=np.float64)
    values = values[values != -1]  # drop EOS / padding markers
    bbox = values[:4].copy()
    bbox[0::2] *= w
    bbox[1::2] *= h

    body = values[4:]
    cut_points = [k for k, v in enumerate(body) if v == 2] + [len(body)]
    polygons = []
    start = 0
    for end in cut_points:
        if end > start:
            polygon = body[start:end].copy()
            polygon[0::2] *= w
            polygon[1::2] *= h
            polygons.append(polygon)
        start = end + 1
    return bbox, polygons


def visual_grounding(image, text):
    """Segment the region of ``image`` described by ``text``.

    Args:
        image: PIL image.
        text: referring expression; lower-cased before being tokenised.

    Returns:
        ``(overlay, mask)``: the RGB image with the predicted mask, polygon
        vertices and bounding box drawn on it, and the predicted binary mask
        as ``uint8`` with values 0 / 255.
    """
    sample = construct_sample(image, text.lower())
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample

    model = models[0] if isinstance(models, list) else models
    model.eval()
    with torch.no_grad():
        gen_out = _decode_polygon_sequence(model, sample, min_len=6, max_len=210, n_bins=64)

    w, h = image.size  # PIL reports (width, height)
    bbox, polygons = _split_bbox_and_polygons(gen_out[0], w, h)
    pred_mask = get_mask_from_codes(polygons, (h, w))
    # Convert so grayscale / RGBA inputs also yield 3-channel arrays that
    # match the RGB overlay colours.
    rgb = np.asarray(image.convert("RGB"))
    pred_overlayed = overlay_predictions(
        rgb, pred_mask, polygons, bbox if len(bbox) == 4 else None
    )
    return pred_overlayed, np.array(pred_mask * 255, dtype=np.uint8)